import torch
import torch.nn as nn
import torch.nn.functional as F


class _GatheredLoss(nn.Module):
    """Shared helpers for losses evaluated at indexed feature-map locations.

    Both regression-style losses in this file first pick ``K`` feature
    vectors out of a dense ``(B, C, H, W)`` map using flat spatial indices,
    then compare only the slots flagged as valid by a mask.
    """

    def _gather_feat(self, feat, ind, mask=None):
        """Pick rows of ``feat`` along dim 1 at the positions given by ``ind``.

        Args:
            feat: Tensor of shape ``(B, N, C)``.
            ind: Integer index tensor of shape ``(B, K)`` addressing dim 1.
            mask: Optional tensor of shape ``(B, K)``; when given, it is used
                to index the gathered features and the result is flattened
                to ``(-1, C)``.

        Returns:
            ``(B, K, C)`` tensor, or ``(M, C)`` when ``mask`` is supplied.
        """
        dim = feat.size(2)
        # Repeat each index across the channel axis so gather() pulls the
        # whole C-dimensional vector for every selected location.
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def _tranpose_and_gather_feat(self, feat, ind):
        """Flatten a ``(B, C, H, W)`` map to ``(B, H*W, C)`` and gather at ``ind``.

        Args:
            feat: Dense prediction map of shape ``(B, C, H, W)``.
            ind: Integer tensor of shape ``(B, K)`` holding flat indices into
                the ``H*W`` spatial positions.

        Returns:
            Tensor of shape ``(B, K, C)``.
        """
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        return self._gather_feat(feat, ind)


class BCELoss(_GatheredLoss):
    """Binary cross-entropy over predictions gathered at indexed locations."""

    def forward(self, output, mask, ind, target):
        """Compute the mean BCE over the slots selected by ``mask``.

        Shapes noted by the original author for one sample:
        ``output`` (1, 1, 152, 152), ``mask`` (1, 500), ``ind`` (1, 500),
        ``target`` (1, 500, 1).

        Args:
            output: Dense prediction map ``(B, C, H, W)``; values must be
                probabilities, as required by ``F.binary_cross_entropy``.
            mask: ``(B, K)`` tensor, non-zero where a slot is valid.
            ind: ``(B, K)`` flat spatial indices into ``output``.
            target: ``(B, K, C)`` ground-truth values.

        Returns:
            Scalar tensor. When no slot is valid, a zero scalar tensor with
            the dtype and device of the predictions is returned, so callers
            always receive a tensor rather than sometimes a Python float.
        """
        pred = self._tranpose_and_gather_feat(output, ind)
        if not mask.sum():
            # Nothing to supervise in this batch.
            return pred.new_zeros(())
        mask = mask.unsqueeze(2).expand_as(pred).bool()
        return F.binary_cross_entropy(pred.masked_select(mask),
                                      target.masked_select(mask),
                                      reduction='mean')


class OffSmoothL1Loss(_GatheredLoss):
    """Smooth-L1 loss over predictions gathered at indexed locations."""

    def forward(self, output, mask, ind, target):
        """Compute the mean smooth-L1 loss over the slots selected by ``mask``.

        Shapes noted by the original author for one sample:
        ``output`` (1, 2, 152, 152), ``mask`` (1, 500), ``ind`` (1, 500),
        ``target`` (1, 500, 2).

        Args:
            output: Dense prediction map ``(B, C, H, W)``.
            mask: ``(B, K)`` tensor, non-zero where a slot is valid.
            ind: ``(B, K)`` flat spatial indices into ``output``.
            target: ``(B, K, C)`` ground-truth values.

        Returns:
            Scalar tensor. When no slot is valid, a zero scalar tensor with
            the dtype and device of the predictions is returned, so callers
            always receive a tensor rather than sometimes a Python float.
        """
        pred = self._tranpose_and_gather_feat(output, ind)
        if not mask.sum():
            # Nothing to supervise in this batch.
            return pred.new_zeros(())
        mask = mask.unsqueeze(2).expand_as(pred).bool()
        return F.smooth_l1_loss(pred.masked_select(mask),
                                target.masked_select(mask),
                                reduction='mean')
class FocalLoss(nn.Module):
    """Focal loss between a predicted heatmap and a soft ground-truth heatmap.

    Locations where ``gt == 1`` are positives and contribute
    ``log(p) * (1 - p)^2``. Locations where ``gt < 1`` are negatives and
    contribute ``log(1 - p) * p^2 * (1 - gt)^4``. The negated sum is divided
    by the number of positives when there is at least one.

    Args:
        eps: Predictions are clamped into ``[eps, 1 - eps]`` before taking
            logarithms. Without this, a prediction of exactly 0 or 1 yields
            ``log(0) * 0 = NaN`` in the branch that is supposed to be masked
            out, which poisons the whole loss.
    """

    def __init__(self, eps=1e-6):
        super(FocalLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt):
        """Return the scalar focal loss.

        Args:
            pred: Predicted probabilities, same shape as ``gt``.
            gt: Ground-truth heatmap with values ``<= 1``; exactly 1 marks a
                positive location.

        Returns:
            Scalar tensor.
        """
        eps = self.eps
        if pred.is_floating_point():
            # Make sure 1 - eps is representable in pred's dtype (matters
            # for half precision, where 1 - 1e-6 rounds back to 1).
            eps = max(eps, torch.finfo(pred.dtype).eps)
        pred = pred.clamp(min=eps, max=1 - eps)

        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        # Down-weight negatives that lie close to a positive peak.
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds).sum()
        neg_loss = (torch.log(1 - pred) * torch.pow(pred, 2)
                    * neg_weights * neg_inds).sum()

        num_pos = pos_inds.sum()
        if num_pos == 0:
            # No positives: only the negative term exists, and there is
            # nothing to normalize by.
            return 0 - neg_loss
        return 0 - (pos_loss + neg_loss) / num_pos


def isnan(x):
    """Return a truthy value when ``x`` is NaN.

    Relies on NaN being the only value that compares unequal to itself, so
    it works for Python floats and for single-element tensors alike.
    """
    return x != x


class LossAll(torch.nn.Module):
    """Sum of the heatmap, size, offset and orientation-class losses."""

    def __init__(self):
        super(LossAll, self).__init__()
        self.L_hm = FocalLoss()
        self.L_wh = OffSmoothL1Loss()
        self.L_off = OffSmoothL1Loss()
        self.L_cls_theta = BCELoss()

    def forward(self, pr_decs, gt_batch):
        """Compute the total training loss.

        Args:
            pr_decs: Mapping with prediction tensors under the keys
                ``'hm'``, ``'wh'``, ``'reg'`` and ``'cls_theta'``.
            gt_batch: Mapping with ground truth under the keys ``'hm'``,
                ``'wh'``, ``'reg'``, ``'cls_theta'``, ``'reg_mask'`` and
                ``'ind'``.

        Returns:
            The unweighted sum of the four component losses.
        """
        reg_mask = gt_batch['reg_mask']
        ind = gt_batch['ind']

        hm_loss = self.L_hm(pr_decs['hm'], gt_batch['hm'])
        wh_loss = self.L_wh(pr_decs['wh'], reg_mask, ind, gt_batch['wh'])
        off_loss = self.L_off(pr_decs['reg'], reg_mask, ind, gt_batch['reg'])
        cls_theta_loss = self.L_cls_theta(pr_decs['cls_theta'], reg_mask, ind,
                                          gt_batch['cls_theta'])

        # Diagnostic only: report every component when any of them is NaN.
        # cls_theta_loss is part of the total, so it is checked as well.
        if (isnan(hm_loss) or isnan(wh_loss) or isnan(off_loss)
                or isnan(cls_theta_loss)):
            print('hm loss is {}'.format(hm_loss))
            print('wh loss is {}'.format(wh_loss))
            print('off loss is {}'.format(off_loss))
            print('cls_theta loss is {}'.format(cls_theta_loss))

        return hm_loss + wh_loss + off_loss + cls_theta_loss