import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from utils.box_utils import match, log_sum_exp
from data import cfg_mnet

# Training-device switch read from the project configuration.
GPU = cfg_mnet['gpu_train']


class MultiBoxLoss(nn.Module):
    """SSD-style weighted multibox loss with an extra landmark term.

    Compute targets:
        1) Produce confidence target indices by matching ground truth boxes
           with (default) 'priorboxes' via ``utils.box_utils.match`` using
           the overlap threshold passed to the constructor.
        2) Produce localization (and landmark) targets by 'encoding' the
           variance into offsets between ground truth and matched priors.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default boxes
           (negative:positive ratio given by ``neg_pos``).

    Objective:
        L(x, c, l, g) = (Lconf(x, c) + alpha * Lloc(x, l, g)) / N
        where Lconf is the cross-entropy loss, Lloc is the Smooth L1 loss and
        N is the number of matched default boxes. The landmark Smooth L1 loss
        is returned separately, normalised by its own positive count.

    See https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        num_classes: number of columns in the confidence predictions.
        overlap_thresh: matching threshold forwarded to ``match``.
        prior_for_matching: stored as ``use_prior_for_matching`` (not read
            by ``forward``).
        bkg_label: stored as ``background_label`` (not read by ``forward``).
        neg_mining: stored as ``do_neg_mining`` (not read by ``forward``).
        neg_pos: negatives kept per positive during hard negative mining.
        neg_overlap: stored as ``neg_overlap`` (not read by ``forward``).
        encode_target: stored as ``encode_target`` (not read by ``forward``).
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # Variances forwarded to ``match`` for target encoding.
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Compute the localization, confidence and landmark losses.

        Args:
            predictions (tuple): ``(loc_data, conf_data, landm_data)``.
                loc shape:   (batch_size, num_priors, 4)
                conf shape:  (batch_size, num_priors, num_classes)
                landm shape: (batch_size, num_priors, 6)
            priors (tensor): prior boxes, shape (num_priors, 4).
            targets: per-image ground truth, indexable by batch position.
                Each entry is a 2-D tensor whose columns ``0:4`` are the
                box, ``4:10`` the landmark values and the last column the
                label.

        Returns:
            tuple: ``(loss_l, loss_c, loss_landm)``. The first two are divided
            by the number of positive priors, the last by the number of
            priors with a strictly positive label (both floored at 1).
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)
        # Landmark values regressed per prior in this variant
        # (columns 4:10 of each target row).
        landm_dim = 6

        # Match priors (default boxes) and ground truth boxes; ``match``
        # fills row ``idx`` of the three target buffers in place.
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, landm_dim)
        conf_t = torch.LongTensor(num, num_priors)
        defaults = priors.data
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:4 + landm_dim].data
            match(self.threshold, truths, defaults, self.variance, labels,
                  landms, loc_t, conf_t, landm_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()

        # Landmark loss (Smooth L1), shape: [batch, num_priors, landm_dim].
        # Only priors with a strictly positive label contribute.
        # NOTE(review): negative labels presumably mark boxes without
        # landmark annotations - confirm against the dataset loader.
        # Comparing against a Python scalar keeps the mask on conf_t's own
        # device, so this works with or without CUDA.
        pos1 = conf_t > 0
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = num_pos_landm.sum().float().clamp(min=1)
        pos_idx1 = pos1.unsqueeze(-1).expand_as(landm_data)
        # -1 lets view() infer the number of selected priors.
        landm_p = landm_data[pos_idx1].view(-1, landm_dim)
        landm_t = landm_t[pos_idx1].view(-1, landm_dim)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # Every non-zero label counts as a positive for the box and class
        # losses; collapse them all to class 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization loss (Smooth L1), shape: [batch, num_priors, 4].
        pos_idx = pos.unsqueeze(-1).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Per-prior confidence loss across the batch, used only to rank
        # negatives for hard negative mining.
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard negative mining.
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        # Sorting the sort indices yields each prior's rank by loss.
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss over the positives plus the mined negatives.
        keep = pos | neg
        keep_idx = keep.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[keep_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[keep]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + alpha*Lloc(x,l,g)) / N
        N = num_pos.sum().float().clamp(min=1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm