Parking slot corner point detection code

"""Train directional marking point detector."""
import math
import random
import numpy as np
import torch
import yaml
from torch import nn
from torch.utils.data import DataLoader
import config
import data
import util
from model import DirectionalPointDetector
from models.yolo import Model

# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def plot_prediction(logger, image, marking_points, prediction):
    """Plot the ground truth and prediction of a random sample in a batch."""
    rand_sample = random.randint(0, image.size(0) - 1)
    sampled_image = util.tensor2im(image[rand_sample])
    logger.plot_marking_points(sampled_image, marking_points[rand_sample],
                               win_name='gt_marking_points')
    sampled_image = util.tensor2im(image[rand_sample])
    pred_points = data.get_predicted_points(prediction[rand_sample], 0.01)
    if pred_points:
        logger.plot_marking_points(sampled_image,
                                   list(list(zip(*pred_points))[1]),
                                   win_name='pred_marking_points')


def generate_objective(marking_points_batch, device):
    """Get regression objective and gradient for directional point detector."""
    batch_size = len(marking_points_batch)
    objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
                            config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
                            device=device)
    gradient = torch.zeros_like(objective)
    gradient[:, 0].fill_(1.)
    for batch_idx, marking_points in enumerate(marking_points_batch):
        for marking_point in marking_points:
            col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
            row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
            # Confidence Regression
            objective[batch_idx, 0, row, col] = 1.
            # Marking Point Shape Regression
            objective[batch_idx, 1, row, col] = marking_point.shape
            # Offset Regression
            objective[batch_idx, 2, row, col] = marking_point.x*config.FEATURE_MAP_SIZE - col
            objective[batch_idx, 3, row, col] = marking_point.y*config.FEATURE_MAP_SIZE - row
            # Direction Regression
            direction = marking_point.direction
            objective[batch_idx, 4, row, col] = math.cos(direction)
            objective[batch_idx, 5, row, col] = math.sin(direction)
            # Assign Gradient
            gradient[batch_idx, 1:6, row, col].fill_(1.)
    return objective, gradient
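
# Target layout written by generate_objective(), per cell of the
# FEATURE_MAP_SIZE x FEATURE_MAP_SIZE grid:
#   channel 0    - confidence (1.0 where a marking point falls in the cell)
#   channel 1    - marking point shape
#   channels 2-3 - x/y offset of the point within its cell, in [0, 1)
#   channels 4-5 - cos/sin of the point direction
# The returned gradient tensor acts as a mask: channel 0 is supervised at every
# cell, channels 1-5 only at occupied cells, so loss.backward(gradient) in
# train_detector() ignores the regression terms of empty cells.
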
# class FocalLoss(nn.Module):
#     # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
#     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
#         super(FocalLoss, self).__init__()
#         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
#         self.gamma = gamma
#         self.alpha = alpha
#         self.reduction = loss_fcn.reduction
#         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
#
#     def forward(self, pred, true):
#         loss = self.loss_fcn(pred, true)
#         # p_t = torch.exp(-loss)
#         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
#
#         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
#         pred_prob = torch.sigmoid(pred)  # prob from logits
#         p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
#         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
#         modulating_factor = (1.0 - p_t) ** self.gamma
#         loss *= alpha_factor * modulating_factor
#
#         if self.reduction == 'mean':
#             return loss.mean()
#         elif self.reduction == 'sum':
#             return loss.sum()
#         else:  # 'none'
#             return loss


def train_detector(args):
    """Train directional point detector."""
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(True)
    # dp_detector = DirectionalPointDetector(
    #     3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    # if args.detector_weights:
    #     print("Loading weights: %s" % args.detector_weights)
    #     dp_detector.load_state_dict(torch.load(args.detector_weights))
    # dp_detector.train()
    with open(args.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)
    dp_detector = Model(args.cfg, ch=3, anchors=hyp.get('anchors')).to(device)
    if args.detector_weights:
        print("Loading weights: %s" % args.detector_weights)
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.train()
    optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
    if args.optimizer_weights:
        print("Loading weights: %s" % args.optimizer_weights)
        optimizer.load_state_dict(torch.load(args.optimizer_weights))
    logger = util.Logger(args.enable_visdom, ['train_loss'])
    data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             pin_memory=True,
                             collate_fn=lambda x: list(zip(*x)))
    # BCEobj = nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.tensor([hyp['obj_pw']], device=device))
    # # Focal loss
    # g = hyp['fl_gamma']  # focal loss gamma
    # if g > 0:
    #     BCEobj = FocalLoss(BCEobj, g)
    for epoch_idx in range(args.num_epochs):
        for iter_idx, (images, marking_points) in enumerate(data_loader):
            images = torch.stack(images).to(device)
            # images = torch.from_numpy(np.stack(images, axis=0)).to(device).permute(0, 3, 1, 2)
            optimizer.zero_grad()
            prediction = dp_detector(images)
            objective, gradient = generate_objective(marking_points, device)
            # lobj = BCEobj(prediction[:, 0, ...], objective[:, 0, ...])
            loss = (prediction - objective) ** 2
            # lobj = torch.unsqueeze(lobj, 1)
            # loss = torch.cat((lobj, l_sxycs), 1)
            loss.backward(gradient)
            optimizer.step()
            train_loss = torch.sum(loss*gradient).item() / loss.size(0)
            logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
            if args.enable_visdom:
                plot_prediction(logger, images, marking_points, prediction)
        torch.save(dp_detector.state_dict(),
                   'weights/dp_detector_%d.pth' % epoch_idx)
        torch.save(optimizer.state_dict(), 'weights/optimizer.pth')


if __name__ == '__main__':
    train_detector(config.get_parser_for_training().parse_args())
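
At inference time the 6-channel output can be decoded by inverting the encoding in generate_objective(): threshold the confidence channel, add the regressed offsets to the cell indices, and recover the direction with atan2 over the sin/cos channels. The sketch below is only an illustration of that encoding; it is not the project's data.get_predicted_points(), which presumably also suppresses overlapping detections. The function name decode_feature_map and the 0.5 threshold are hypothetical.

import math

def decode_feature_map(feature_map, thresh=0.5):
    """Sketch: invert generate_objective() for one 6-channel torch feature map."""
    feature_size = feature_map.size(-1)
    predicted_points = []
    for row in range(feature_size):
        for col in range(feature_size):
            confidence = feature_map[0, row, col].item()
            if confidence < thresh:
                continue
            shape = feature_map[1, row, col].item()
            # Cell index plus regressed offset, normalized back to [0, 1)
            x = (col + feature_map[2, row, col].item()) / feature_size
            y = (row + feature_map[3, row, col].item()) / feature_size
            # Direction from the cos/sin channels
            direction = math.atan2(feature_map[5, row, col].item(),
                                   feature_map[4, row, col].item())
            predicted_points.append((confidence, (x, y, shape, direction)))
    return predicted_points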