用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

197 lines
8.1KB

  1. """Custom losses."""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. __all__ = ['MixSoftmaxCrossEntropyLoss', 'MixSoftmaxCrossEntropyOHEMLoss',
  7. 'EncNetLoss', 'ICNetLoss', 'get_segmentation_loss']
  8. # TODO: optim function
  9. class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
  10. def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
  11. super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)
  12. self.aux = aux
  13. self.aux_weight = aux_weight
  14. def _aux_forward(self, *inputs, **kwargs):
  15. *preds, target = tuple(inputs)
  16. loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target)
  17. for i in range(1, len(preds)):
  18. aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[i], target)
  19. loss += self.aux_weight * aux_loss
  20. return loss
  21. def forward(self, *inputs, **kwargs):
  22. preds, target = tuple(inputs)
  23. inputs = tuple(list(preds) + [target])
  24. if self.aux:
  25. return dict(loss=self._aux_forward(*inputs))
  26. else:
  27. return dict(loss=super(MixSoftmaxCrossEntropyLoss, self).forward(*inputs))
  28. # reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/loss.py
  29. class EncNetLoss(nn.CrossEntropyLoss):
  30. """2D Cross Entropy Loss with SE Loss"""
  31. def __init__(self, se_loss=True, se_weight=0.2, nclass=19, aux=False,
  32. aux_weight=0.4, weight=None, ignore_index=-1, **kwargs):
  33. super(EncNetLoss, self).__init__(weight, None, ignore_index)
  34. self.se_loss = se_loss
  35. self.aux = aux
  36. self.nclass = nclass
  37. self.se_weight = se_weight
  38. self.aux_weight = aux_weight
  39. self.bceloss = nn.BCELoss(weight)
  40. def forward(self, *inputs):
  41. preds, target = tuple(inputs)
  42. inputs = tuple(list(preds) + [target])
  43. if not self.se_loss and not self.aux:
  44. return super(EncNetLoss, self).forward(*inputs)
  45. elif not self.se_loss:
  46. pred1, pred2, target = tuple(inputs)
  47. loss1 = super(EncNetLoss, self).forward(pred1, target)
  48. loss2 = super(EncNetLoss, self).forward(pred2, target)
  49. return dict(loss=loss1 + self.aux_weight * loss2)
  50. elif not self.aux:
  51. print (inputs)
  52. pred, se_pred, target = tuple(inputs)
  53. se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
  54. loss1 = super(EncNetLoss, self).forward(pred, target)
  55. loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
  56. return dict(loss=loss1 + self.se_weight * loss2)
  57. else:
  58. pred1, se_pred, pred2, target = tuple(inputs)
  59. se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
  60. loss1 = super(EncNetLoss, self).forward(pred1, target)
  61. loss2 = super(EncNetLoss, self).forward(pred2, target)
  62. loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
  63. return dict(loss=loss1 + self.aux_weight * loss2 + self.se_weight * loss3)
  64. @staticmethod
  65. def _get_batch_label_vector(target, nclass):
  66. # target is a 3D Variable BxHxW, output is 2D BxnClass
  67. batch = target.size(0)
  68. tvect = Variable(torch.zeros(batch, nclass))
  69. for i in range(batch):
  70. hist = torch.histc(target[i].cpu().data.float(),
  71. bins=nclass, min=0,
  72. max=nclass - 1)
  73. vect = hist > 0
  74. tvect[i] = vect
  75. return tvect
  76. # TODO: optim function
  77. class ICNetLoss(nn.CrossEntropyLoss):
  78. """Cross Entropy Loss for ICNet"""
  79. def __init__(self, nclass, aux_weight=0.4, ignore_index=-1, **kwargs):
  80. super(ICNetLoss, self).__init__(ignore_index=ignore_index)
  81. self.nclass = nclass
  82. self.aux_weight = aux_weight
  83. def forward(self, *inputs):
  84. preds, target = tuple(inputs)
  85. inputs = tuple(list(preds) + [target])
  86. pred, pred_sub4, pred_sub8, pred_sub16, target = tuple(inputs)
  87. # [batch, W, H] -> [batch, 1, W, H]
  88. target = target.unsqueeze(1).float()
  89. target_sub4 = F.interpolate(target, pred_sub4.size()[2:], mode='bilinear', align_corners=True).squeeze(1).long()
  90. target_sub8 = F.interpolate(target, pred_sub8.size()[2:], mode='bilinear', align_corners=True).squeeze(1).long()
  91. target_sub16 = F.interpolate(target, pred_sub16.size()[2:], mode='bilinear', align_corners=True).squeeze(
  92. 1).long()
  93. loss1 = super(ICNetLoss, self).forward(pred_sub4, target_sub4)
  94. loss2 = super(ICNetLoss, self).forward(pred_sub8, target_sub8)
  95. loss3 = super(ICNetLoss, self).forward(pred_sub16, target_sub16)
  96. return dict(loss=loss1 + loss2 * self.aux_weight + loss3 * self.aux_weight)
  97. class OhemCrossEntropy2d(nn.Module):
  98. def __init__(self, ignore_index=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs):
  99. super(OhemCrossEntropy2d, self).__init__()
  100. self.ignore_index = ignore_index
  101. self.thresh = float(thresh)
  102. self.min_kept = int(min_kept)
  103. if use_weight:
  104. weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
  105. 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
  106. 1.0865, 1.1529, 1.0507])
  107. self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
  108. else:
  109. self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
  110. def forward(self, pred, target):
  111. n, c, h, w = pred.size()
  112. target = target.view(-1)
  113. valid_mask = target.ne(self.ignore_index)
  114. target = target * valid_mask.long()
  115. num_valid = valid_mask.sum()
  116. prob = F.softmax(pred, dim=1)
  117. prob = prob.transpose(0, 1).reshape(c, -1)
  118. if self.min_kept > num_valid:
  119. print("Lables: {}".format(num_valid))
  120. elif num_valid > 0:
  121. prob = prob.masked_fill_(1 - valid_mask, 1)
  122. mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
  123. threshold = self.thresh
  124. if self.min_kept > 0:
  125. index = mask_prob.argsort()
  126. threshold_index = index[min(len(index), self.min_kept) - 1]
  127. if mask_prob[threshold_index] > self.thresh:
  128. threshold = mask_prob[threshold_index]
  129. kept_mask = mask_prob.le(threshold)
  130. valid_mask = valid_mask * kept_mask
  131. target = target * kept_mask.long()
  132. target = target.masked_fill_(1 - valid_mask, self.ignore_index)
  133. target = target.view(n, h, w)
  134. return self.criterion(pred, target)
  135. class MixSoftmaxCrossEntropyOHEMLoss(OhemCrossEntropy2d):
  136. def __init__(self, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **kwargs):
  137. super(MixSoftmaxCrossEntropyOHEMLoss, self).__init__(ignore_index=ignore_index)
  138. self.aux = aux
  139. self.aux_weight = aux_weight
  140. self.bceloss = nn.BCELoss(weight)
  141. def _aux_forward(self, *inputs, **kwargs):
  142. *preds, target = tuple(inputs)
  143. loss = super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(preds[0], target)
  144. for i in range(1, len(preds)):
  145. aux_loss = super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(preds[i], target)
  146. loss += self.aux_weight * aux_loss
  147. return loss
  148. def forward(self, *inputs):
  149. preds, target = tuple(inputs)
  150. inputs = tuple(list(preds) + [target])
  151. if self.aux:
  152. return dict(loss=self._aux_forward(*inputs))
  153. else:
  154. return dict(loss=super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(*inputs))
  155. def get_segmentation_loss(model, use_ohem=False, **kwargs):
  156. if use_ohem:
  157. return MixSoftmaxCrossEntropyOHEMLoss(**kwargs)
  158. model = model.lower()
  159. if model == 'encnet':
  160. return EncNetLoss(**kwargs)
  161. elif model == 'icnet':
  162. return ICNetLoss(nclass=4, **kwargs)
  163. else:
  164. return MixSoftmaxCrossEntropyLoss(**kwargs)