车位角点检测代码
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 5 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 6 Jahren
vor 5 Jahren
vor 6 Jahren
vor 5 Jahren
vor 6 Jahren
vor 6 Jahren
vor 5 Jahren
vor 6 Jahren
vor 6 Jahren
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """Train directional marking point detector."""
  2. import math
  3. import random
  4. import torch
  5. from torch.utils.data import DataLoader
  6. import config
  7. import data
  8. import util
  9. from model import DirectionalPointDetector
  10. def plot_prediction(logger, image, marking_points, prediction):
  11. """Plot the ground truth and prediction of a random sample in a batch."""
  12. rand_sample = random.randint(0, image.size(0)-1)
  13. sampled_image = util.tensor2im(image[rand_sample])
  14. logger.plot_marking_points(sampled_image, marking_points[rand_sample],
  15. win_name='gt_marking_points')
  16. sampled_image = util.tensor2im(image[rand_sample])
  17. pred_points = data.get_predicted_points(prediction[rand_sample], 0.01)
  18. if pred_points:
  19. logger.plot_marking_points(sampled_image,
  20. list(list(zip(*pred_points))[1]),
  21. win_name='pred_marking_points')
  22. def generate_objective(marking_points_batch, device):
  23. """Get regression objective and gradient for directional point detector."""
  24. batch_size = len(marking_points_batch)
  25. objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
  26. config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
  27. device=device)
  28. gradient = torch.zeros_like(objective)
  29. gradient[:, 0].fill_(1.)
  30. for batch_idx, marking_points in enumerate(marking_points_batch):
  31. for marking_point in marking_points:
  32. col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
  33. row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
  34. # Confidence Regression
  35. objective[batch_idx, 0, row, col] = 1.
  36. # Makring Point Shape Regression
  37. objective[batch_idx, 1, row, col] = marking_point.shape
  38. # Offset Regression
  39. objective[batch_idx, 2, row, col] = marking_point.x*16 - col
  40. objective[batch_idx, 3, row, col] = marking_point.y*16 - row
  41. # Direction Regression
  42. direction = marking_point.direction
  43. objective[batch_idx, 4, row, col] = math.cos(direction)
  44. objective[batch_idx, 5, row, col] = math.sin(direction)
  45. # Assign Gradient
  46. gradient[batch_idx, 1:6, row, col].fill_(1.)
  47. return objective, gradient
  48. def train_detector(args):
  49. """Train directional point detector."""
  50. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  51. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  52. torch.set_grad_enabled(True)
  53. dp_detector = DirectionalPointDetector(
  54. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  55. if args.detector_weights:
  56. print("Loading weights: %s" % args.detector_weights)
  57. dp_detector.load_state_dict(torch.load(args.detector_weights))
  58. dp_detector.train()
  59. optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
  60. if args.optimizer_weights:
  61. print("Loading weights: %s" % args.optimizer_weights)
  62. optimizer.load_state_dict(torch.load(args.optimizer_weights))
  63. logger = util.Logger(args.enable_visdom, ['train_loss'])
  64. data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
  65. batch_size=args.batch_size, shuffle=True,
  66. num_workers=args.data_loading_workers,
  67. collate_fn=lambda x: list(zip(*x)))
  68. for epoch_idx in range(args.num_epochs):
  69. for iter_idx, (images, marking_points) in enumerate(data_loader):
  70. images = torch.stack(images).to(device)
  71. optimizer.zero_grad()
  72. prediction = dp_detector(images)
  73. objective, gradient = generate_objective(marking_points, device)
  74. loss = (prediction - objective) ** 2
  75. loss.backward(gradient)
  76. optimizer.step()
  77. train_loss = torch.sum(loss*gradient).item() / loss.size(0)
  78. logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
  79. if args.enable_visdom:
  80. plot_prediction(logger, images, marking_points, prediction)
  81. torch.save(dp_detector.state_dict(),
  82. 'weights/dp_detector_%d.pth' % epoch_idx)
  83. torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
  84. if __name__ == '__main__':
  85. train_detector(config.get_parser_for_training().parse_args())