Parking-slot corner point detection code

train.py (6.7 KB)

  1. """Train directional marking point detector."""
  2. import math
  3. import random
  4. import numpy as np
  5. import torch
  6. import yaml
  7. from torch import nn
  8. from torch.utils.data import DataLoader
  9. import config
  10. import data
  11. import util
  12. from model import DirectionalPointDetector
  13. from models.yolo import Model
  14. # import os
  15. # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
  16. def plot_prediction(logger, image, marking_points, prediction):
  17. """Plot the ground truth and prediction of a random sample in a batch."""
  18. rand_sample = random.randint(0, image.size(0)-1)
  19. sampled_image = util.tensor2im(image[rand_sample])
  20. logger.plot_marking_points(sampled_image, marking_points[rand_sample],
  21. win_name='gt_marking_points')
  22. sampled_image = util.tensor2im(image[rand_sample])
  23. pred_points = data.get_predicted_points(prediction[rand_sample], 0.01)
  24. if pred_points:
  25. logger.plot_marking_points(sampled_image,
  26. list(list(zip(*pred_points))[1]),
  27. win_name='pred_marking_points')
  28. def generate_objective(marking_points_batch, device):
  29. """Get regression objective and gradient for directional point detector."""
  30. batch_size = len(marking_points_batch)
  31. objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
  32. config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
  33. device=device)
  34. gradient = torch.zeros_like(objective)
  35. gradient[:, 0].fill_(1.)
  36. for batch_idx, marking_points in enumerate(marking_points_batch):
  37. for marking_point in marking_points:
  38. col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
  39. row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
  40. # Confidence Regression
  41. objective[batch_idx, 0, row, col] = 1.
  42. # Makring Point Shape Regression
  43. objective[batch_idx, 1, row, col] = marking_point.shape
  44. # Offset Regression
  45. objective[batch_idx, 2, row, col] = marking_point.x*config.FEATURE_MAP_SIZE - col
  46. objective[batch_idx, 3, row, col] = marking_point.y*config.FEATURE_MAP_SIZE - row
  47. # Direction Regression
  48. direction = marking_point.direction
  49. objective[batch_idx, 4, row, col] = math.cos(direction)
  50. objective[batch_idx, 5, row, col] = math.sin(direction)
  51. # Assign Gradient
  52. gradient[batch_idx, 1:6, row, col].fill_(1.)
  53. return objective, gradient
  54. # class FocalLoss(nn.Module):
  55. # # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  56. # def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  57. # super(FocalLoss, self).__init__()
  58. # self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  59. # self.gamma = gamma
  60. # self.alpha = alpha
  61. # self.reduction = loss_fcn.reduction
  62. # self.loss_fcn.reduction = 'none' # required to apply FL to each element
  63. #
  64. # def forward(self, pred, true):
  65. # loss = self.loss_fcn(pred, true)
  66. # # p_t = torch.exp(-loss)
  67. # # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  68. #
  69. # # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  70. # pred_prob = torch.sigmoid(pred) # prob from logits
  71. # p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  72. # alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  73. # modulating_factor = (1.0 - p_t) ** self.gamma
  74. # loss *= alpha_factor * modulating_factor
  75. #
  76. # if self.reduction == 'mean':
  77. # return loss.mean()
  78. # elif self.reduction == 'sum':
  79. # return loss.sum()
  80. # else: # 'none'
  81. # return loss
  82. def train_detector(args):
  83. """Train directional point detector."""
  84. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  85. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  86. torch.set_grad_enabled(True)
  87. # dp_detector = DirectionalPointDetector(
  88. # 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  89. # if args.detector_weights:
  90. # print("Loading weights: %s" % args.detector_weights)
  91. # dp_detector.load_state_dict(torch.load(args.detector_weights))
  92. # dp_detector.train()
  93. with open(args.hyp) as f:
  94. hyp = yaml.load(f, Loader=yaml.SafeLoader)
  95. dp_detector = Model(args.cfg, ch=3, anchors=hyp.get('anchors')).to(device)
  96. if args.detector_weights:
  97. print("Loading weights: %s" % args.detector_weights)
  98. dp_detector.load_state_dict(torch.load(args.detector_weights))
  99. dp_detector.train()
  100. optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
  101. if args.optimizer_weights:
  102. print("Loading weights: %s" % args.optimizer_weights)
  103. optimizer.load_state_dict(torch.load(args.optimizer_weights))
  104. logger = util.Logger(args.enable_visdom, ['train_loss'])
  105. data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
  106. batch_size=args.batch_size, shuffle=True,
  107. num_workers=args.data_loading_workers,
  108. pin_memory=True,
  109. collate_fn=lambda x: list(zip(*x)))
  110. # BCEobj = nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.tensor([hyp['obj_pw']], device=device))
  111. # # Focal loss
  112. # g = hyp['fl_gamma'] # focal loss gamma
  113. # if g > 0:
  114. # BCEobj = FocalLoss(BCEobj, g)
  115. for epoch_idx in range(args.num_epochs):
  116. for iter_idx, (images, marking_points) in enumerate(data_loader):
  117. images = torch.stack(images).to(device)
  118. # images = torch.from_numpy(np.stack(images, axis=0)).to(device).permute(0, 3, 1, 2)
  119. optimizer.zero_grad()
  120. prediction = dp_detector(images)
  121. objective, gradient = generate_objective(marking_points, device)
  122. # lobj = BCEobj(prediction[:, 0, ...], objective[:, 0, ...])
  123. loss = (prediction - objective) ** 2
  124. # lobj = torch.unsqueeze(lobj, 1)
  125. # loss = torch.cat((lobj, l_sxycs), 1)
  126. loss.backward(gradient)
  127. optimizer.step()
  128. train_loss = torch.sum(loss*gradient).item() / loss.size(0)
  129. logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
  130. if args.enable_visdom:
  131. plot_prediction(logger, images, marking_points, prediction)
  132. torch.save(dp_detector.state_dict(),
  133. 'weights/dp_detector_%d.pth' % epoch_idx)
  134. torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
  135. if __name__ == '__main__':
  136. train_detector(config.get_parser_for_training().parse_args())
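For reference, generate_objective lays the regression target out as six channels per cell of the FEATURE_MAP_SIZE x FEATURE_MAP_SIZE grid: confidence, marking-point shape, the x/y offsets inside the cell, and the cosine/sine of the point direction. At inference time the repository decodes predictions with data.get_predicted_points (called above with a 0.01 threshold); its internals are not shown here, so the sketch below is only a hypothetical re-implementation of that inverse mapping, assuming the prediction tensor follows the same channel order as the training objective (the function name decode_marking_points and the absence of non-maximum suppression are this sketch's own choices, not the project's).

import math
import torch


def decode_marking_points(prediction, thresh=0.01):
    """Sketch of the inverse of generate_objective for one sample.

    `prediction` is a (6, S, S) tensor laid out like the training
    objective: channel 0 confidence, 1 shape, 2-3 in-cell x/y offsets,
    4-5 cos/sin of the direction. Returns (confidence, x, y, shape,
    direction) tuples with x/y normalized to [0, 1). Hypothetical;
    not the repository's data.get_predicted_points.
    """
    num_channels, size, _ = prediction.shape
    assert num_channels == 6
    points = []
    for row in range(size):
        for col in range(size):
            conf = prediction[0, row, col].item()
            if conf < thresh:
                continue
            # Undo the offset regression: cell index plus in-cell offset.
            x = (col + prediction[2, row, col].item()) / size
            y = (row + prediction[3, row, col].item()) / size
            shape = prediction[1, row, col].item()
            # Direction was encoded as (cos, sin); recover the angle.
            direction = math.atan2(prediction[5, row, col].item(),
                                   prediction[4, row, col].item())
            points.append((conf, x, y, shape, direction))
    return points


if __name__ == '__main__':
    # Smoke test on a random 6-channel map over a 16x16 grid.
    dummy = torch.rand(6, 16, 16)
    print(len(decode_marking_points(dummy, thresh=0.5)))

In the training loop above, this same layout is what (prediction - objective) ** 2 together with the gradient mask supervises: the confidence channel is penalized at every cell, while the remaining five channels are penalized only at cells that contain a ground-truth marking point.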