- """Train directional marking point detector."""
- import math
- import random
-
- import numpy as np
- import torch
- import yaml
- from torch import nn
- from torch.utils.data import DataLoader
- import config
- import data
- import util
- from model import DirectionalPointDetector
- from models.yolo import Model
-
- # import os
- # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
-
-
- def plot_prediction(logger, image, marking_points, prediction):
- """Plot the ground truth and prediction of a random sample in a batch."""
- rand_sample = random.randint(0, image.size(0)-1)
- sampled_image = util.tensor2im(image[rand_sample])
- logger.plot_marking_points(sampled_image, marking_points[rand_sample],
- win_name='gt_marking_points')
- sampled_image = util.tensor2im(image[rand_sample])
- pred_points = data.get_predicted_points(prediction[rand_sample], 0.01)
- if pred_points:
- logger.plot_marking_points(sampled_image,
- list(list(zip(*pred_points))[1]),
- win_name='pred_marking_points')
-
-
- def generate_objective(marking_points_batch, device):
- """Get regression objective and gradient for directional point detector."""
- batch_size = len(marking_points_batch)
- objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
- config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
- device=device)
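    # Objective layout per feature-map cell: channel 0 = confidence,
    # 1 = marking-point shape, 2-3 = x/y offsets within the cell,
    # 4-5 = direction encoded as (cos, sin).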
    gradient = torch.zeros_like(objective)
    gradient[:, 0].fill_(1.)
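    # The gradient tensor acts as a per-element loss mask: the confidence channel
    # is supervised at every cell, while the remaining channels are enabled only
    # at cells that contain a ground-truth marking point (set in the loop below).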
    for batch_idx, marking_points in enumerate(marking_points_batch):
        for marking_point in marking_points:
            col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
            row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
            # Confidence Regression
            objective[batch_idx, 0, row, col] = 1.
            # Marking Point Shape Regression
            objective[batch_idx, 1, row, col] = marking_point.shape
            # Offset Regression
            objective[batch_idx, 2, row, col] = marking_point.x*config.FEATURE_MAP_SIZE - col
            objective[batch_idx, 3, row, col] = marking_point.y*config.FEATURE_MAP_SIZE - row
            # Direction Regression
            direction = marking_point.direction
            objective[batch_idx, 4, row, col] = math.cos(direction)
            objective[batch_idx, 5, row, col] = math.sin(direction)
            # Assign Gradient
            gradient[batch_idx, 1:6, row, col].fill_(1.)
    return objective, gradient


# class FocalLoss(nn.Module):
#     # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
#     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
#         super(FocalLoss, self).__init__()
#         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
#         self.gamma = gamma
#         self.alpha = alpha
#         self.reduction = loss_fcn.reduction
#         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
#
#     def forward(self, pred, true):
#         loss = self.loss_fcn(pred, true)
#         # p_t = torch.exp(-loss)
#         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
#
#         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
#         pred_prob = torch.sigmoid(pred)  # prob from logits
#         p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
#         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
#         modulating_factor = (1.0 - p_t) ** self.gamma
#         loss *= alpha_factor * modulating_factor
#
#         if self.reduction == 'mean':
#             return loss.mean()
#         elif self.reduction == 'sum':
#             return loss.sum()
#         else:  # 'none'
#             return loss


def train_detector(args):
- """Train directional point detector."""
- args.cuda = not args.disable_cuda and torch.cuda.is_available()
- device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
- torch.set_grad_enabled(True)
-
- # dp_detector = DirectionalPointDetector(
- # 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
- # if args.detector_weights:
- # print("Loading weights: %s" % args.detector_weights)
- # dp_detector.load_state_dict(torch.load(args.detector_weights))
- # dp_detector.train()
-
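    # Build the detector from the YOLO model definition in args.cfg, using the
    # anchors from the hyperparameter YAML; this replaces the original
    # DirectionalPointDetector backbone commented out above.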
    with open(args.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)
    dp_detector = Model(args.cfg, ch=3, anchors=hyp.get('anchors')).to(device)
    if args.detector_weights:
        print("Loading weights: %s" % args.detector_weights)
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.train()
    optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
    if args.optimizer_weights:
        print("Loading weights: %s" % args.optimizer_weights)
        optimizer.load_state_dict(torch.load(args.optimizer_weights))

    logger = util.Logger(args.enable_visdom, ['train_loss'])
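    # The custom collate_fn transposes each batch into (images, marking_points)
    # tuples; the number of marking points varies per sample, so they are kept
    # as Python lists rather than stacked into a tensor.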
    data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             pin_memory=True,
                             collate_fn=lambda x: list(zip(*x)))

    # BCEobj = nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.tensor([hyp['obj_pw']], device=device))

    # # Focal loss
    # g = hyp['fl_gamma']  # focal loss gamma
    # if g > 0:
    #     BCEobj = FocalLoss(BCEobj, g)

    for epoch_idx in range(args.num_epochs):
        for iter_idx, (images, marking_points) in enumerate(data_loader):
            images = torch.stack(images).to(device)
            # images = torch.from_numpy(np.stack(images, axis=0)).to(device).permute(0, 3, 1, 2)
            optimizer.zero_grad()
            prediction = dp_detector(images)
            objective, gradient = generate_objective(marking_points, device)
            # lobj = BCEobj(prediction[:, 0, ...], objective[:, 0, ...])
            loss = (prediction - objective) ** 2
            # lobj = torch.unsqueeze(lobj, 1)
            # loss = torch.cat((lobj, l_sxycs), 1)
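            # backward(gradient) backpropagates sum(loss * gradient): the constant
            # gradient tensor masks the squared error so only the supervised
            # channels/cells from generate_objective() contribute to the update.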
            loss.backward(gradient)
            optimizer.step()

            train_loss = torch.sum(loss*gradient).item() / loss.size(0)
            logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
            if args.enable_visdom:
                plot_prediction(logger, images, marking_points, prediction)
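        # Save the detector and optimizer state at the end of every epoch.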
        torch.save(dp_detector.state_dict(),
                   'weights/dp_detector_%d.pth' % epoch_idx)
        torch.save(optimizer.state_dict(), 'weights/optimizer.pth')


if __name__ == '__main__':
    train_detector(config.get_parser_for_training().parse_args())