|
- #!/usr/bin/python
- # -*- encoding: utf-8 -*-
-
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from loss.util import enet_weighing
- import numpy as np
-
-
class OhemCELoss(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    Per-element cross-entropy losses are sorted in descending order. If the
    ``n_min``-th largest loss (0-based index ``n_min``) already exceeds the
    loss threshold ``-log(thresh)``, every loss above that threshold is
    averaged; otherwise the ``n_min`` largest losses are averaged.

    Args:
        thresh: Probability threshold in (0, 1]; it is converted to the loss
            threshold ``-log(thresh)``.
        n_min: Minimum number of hardest elements that contribute to the loss.
        ignore_lb: Label value passed to ``nn.CrossEntropyLoss`` as
            ``ignore_index``.
    """

    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # Keep the threshold as a plain Python float so the module works on any
        # device (a tensor attribute that is not a buffer would not follow
        # .to()/.cuda()). It is computed in float32 so the comparison value
        # matches the float32 threshold tensor used previously.
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).item()
        self.n_min = n_min
        self.ignore_lb = ignore_lb

        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        """Return the mean of the mined hard-example losses.

        Args:
            logits: Raw class scores in any shape accepted by
                ``nn.CrossEntropyLoss`` (class dimension at dim 1).
            labels: Integer class targets matching ``logits`` without dim 1.

        Returns:
            A scalar tensor. With ``reduction='none'``, ``nn.CrossEntropyLoss``
            yields 0 for elements labelled ``ignore_lb``. Such elements can
            still be counted in the ``n_min`` fallback branch.
        """
        loss = self.criteria(logits, labels).reshape(-1)
        loss, _ = torch.sort(loss, descending=True)
        # If there are at most n_min elements, loss[self.n_min] would be out of
        # range. In that case fall through and average every element.
        if loss.numel() > self.n_min and loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)
-
class WeightedOhemCELoss(nn.Module):
    """Class-weighted cross-entropy with online hard example mining (OHEM).

    This behaves like an OHEM cross-entropy, except that per-class weights are
    recomputed from ``labels`` on every call via ``enet_weighing`` and passed
    as the ``weight`` argument of the cross-entropy.

    Args:
        thresh: Probability threshold in (0, 1]; it is converted to the loss
            threshold ``-log(thresh)``.
        n_min: Minimum number of hardest elements that contribute to the loss.
        num_classes: Number of classes, forwarded to ``enet_weighing``.
        ignore_lb: Label value excluded from the loss (``ignore_index``).
    """

    def __init__(self, thresh, n_min, num_classes, ignore_lb=255, *args, **kwargs):
        super(WeightedOhemCELoss, self).__init__()
        # Plain float so the module is device-agnostic. It is computed in
        # float32 so the comparison value matches the previous threshold tensor.
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).item()
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.num_classes = num_classes

    def forward(self, logits, labels):
        """Return the mean of the mined, class-weighted hard-example losses.

        Args:
            logits: Raw class scores with the class dimension at dim 1.
            labels: Integer class targets matching ``logits`` without dim 1.

        Returns:
            A scalar tensor.
        """
        # NOTE(review): enet_weighing is assumed to return a per-class weight
        # tensor of length num_classes (ENet-style class balancing). Confirm
        # in loss.util how it treats ignore_lb entries in labels.
        weight = enet_weighing(labels, self.num_classes).to(logits.device)
        loss = F.cross_entropy(logits, labels, weight=weight,
                               ignore_index=self.ignore_lb,
                               reduction='none').reshape(-1)
        loss, _ = torch.sort(loss, descending=True)
        # Guard against n_min >= number of elements. In that case loss[n_min]
        # would be out of range, so average every element instead.
        if loss.numel() > self.n_min and loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)
-
class SoftmaxFocalLoss(nn.Module):
    """Multi-class focal loss computed on top of softmax probabilities.

    For each element with target class ``y`` and softmax probability ``p_y``,
    the loss is ``-(1 - p_y) ** gamma * log(p_y)``. It is averaged by
    ``nn.NLLLoss`` over the elements whose label is not ``ignore_lb``. With
    ``gamma == 0`` this reduces to ordinary cross-entropy.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        ignore_lb: Label value passed to ``nn.NLLLoss`` as ``ignore_index``.
    """

    def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
        # The original called super(FocalLoss, self), which names an undefined
        # class and raised NameError on construction.
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.nll = nn.NLLLoss(ignore_index=ignore_lb)

    def forward(self, logits, labels):
        """Return the mean focal loss for ``logits`` against ``labels``.

        Args:
            logits: Raw class scores with the class dimension at dim 1.
            labels: Integer class targets matching ``logits`` without dim 1.
        """
        scores = F.softmax(logits, dim=1)
        # Down-weight well-classified elements: the factor goes to 0 as p -> 1.
        factor = torch.pow(1. - scores, self.gamma)
        log_score = F.log_softmax(logits, dim=1)
        log_score = factor * log_score
        # NLLLoss selects -log_score at the target class and averages it.
        loss = self.nll(log_score, labels)
        return loss
-
-
if __name__ == '__main__':
    # Smoke test: push random data through two tiny conv nets, apply
    # OhemCELoss to each output, and backpropagate the summed loss.
    torch.manual_seed(15)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).to(device)
    criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).to(device)
    net1 = nn.Sequential(
        nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
    )
    net1.to(device)
    net1.train()
    net2 = nn.Sequential(
        nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
    )
    net2.to(device)
    net2.train()

    # Only the inputs are created without autograd. The original also ran the
    # forward pass under no_grad, so loss.backward() failed with "element 0 of
    # tensors does not require grad".
    with torch.no_grad():
        inten = torch.randn(16, 3, 20, 20).to(device)
        lbs = torch.randint(0, 19, [16, 20, 20]).to(device)
        # Mark one whole sample with label 255 (OhemCELoss's default ignore_lb).
        lbs[1, :, :] = 255

    # align_corners=False is PyTorch's default; passing it explicitly
    # silences the bilinear-interpolation warning.
    logits1 = net1(inten)
    logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear', align_corners=False)
    logits2 = net2(inten)
    logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear', align_corners=False)

    loss1 = criteria1(logits1, lbs)
    loss2 = criteria2(logits2, lbs)
    loss = loss1 + loss2
    print(loss.detach().cpu())
    loss.backward()
|