import torch
import torch.nn as nn
from torch.autograd import Variable as V
import cv2
import numpy as np


def _dice_score(intersection, i, j, smooth=0.0):
    """Return the Dice ratio ``(2 * intersection + smooth) / (i + j + smooth)``.

    Evaluated element-wise on tensors. Where numerator and denominator are
    both zero (no foreground in either input and ``smooth == 0``) the ratio
    is defined as 1 (perfect agreement) instead of the NaN ``0 / 0`` would
    give. A non-zero numerator over a zero denominator still yields inf.
    """
    numerator = 2.0 * intersection + smooth
    denominator = i + j + smooth
    empty = (numerator == 0) & (denominator == 0)
    # Divide by a safe denominator first: torch.where does not stop NaN/inf
    # gradients flowing from the unselected branch, so the 0/0 must never be
    # computed at all.
    safe_denominator = torch.where(empty, torch.ones_like(denominator), denominator)
    ratio = numerator / safe_denominator
    return torch.where(empty, torch.ones_like(ratio), ratio)


def _per_sample_sums(y_pred, y_true):
    """Return per-sample ``(sum(y_true), sum(y_pred), sum(y_true * y_pred))``.

    Only dimensions 1, 2 and 3 are reduced, so both inputs need at least
    4 dimensions (presumably (N, C, H, W) -- not checked here).
    """
    dims = (1, 2, 3)
    return (
        y_true.sum(dim=dims),
        y_pred.sum(dim=dims),
        (y_true * y_pred).sum(dim=dims),
    )


class dice_bce_loss(nn.Module):
    """Weighted sum of ``nn.CrossEntropyLoss`` and a soft Dice loss.

    Despite the name, the first term is ``nn.CrossEntropyLoss`` (multi-class
    cross-entropy), not binary cross-entropy. With the default weights the
    result equals the cross-entropy alone, and the Dice term is not
    evaluated at all. Skipping it means CE-style inputs (logits vs. class
    indices) whose shapes do not broadcast cannot break the call, and a
    NaN/inf Dice value cannot poison the total (``0.0 * nan`` is NaN).

    Args:
        batch: If True, compute one soft Dice score over the whole batch;
            otherwise compute one score per sample (over dims 1-3) and
            average them.
        bce_weight: Weight of the cross-entropy term.
        dice_weight: Weight of the soft Dice loss term; 0 disables it.
        smooth: Additive smoothing in the Dice numerator and denominator.
    """

    def __init__(self, batch=True, bce_weight=1.0, dice_weight=0.0, smooth=0.0):
        super().__init__()
        self.batch = batch
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth
        self.bce_loss = nn.CrossEntropyLoss()

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the mean soft Dice coefficient between ``y_pred`` and ``y_true``."""
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            i, j, intersection = _per_sample_sums(y_pred, y_true)
        return _dice_score(intersection, i, j, self.smooth).mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * CE(y_pred, y_true) + dice_weight * dice_loss``."""
        loss = self.bce_weight * self.bce_loss(y_pred, y_true)
        if self.dice_weight:
            loss = loss + self.dice_weight * self.soft_dice_loss(y_pred, y_true)
        return loss


class dice_loss(nn.Module):
    """Dice loss, ``1 - Dice coefficient``, for segmentation outputs.

    Args:
        batch: If True, ``y_pred`` and ``y_true`` are treated as
            index-encoded label maps (0 = background, positive values =
            foreground classes). A single hard Dice score is then computed
            over the whole batch by counting label matches. If False, the
            soft formula ``2 * sum(t * p) / (sum(t) + sum(p))`` is evaluated
            per sample over dims 1-3 and averaged. That formula agrees with
            the counting version only for binary 0/1 masks.
        smooth: Additive smoothing in the Dice numerator and denominator.

    Note:
        The batch branch is built from comparisons, so its result carries
        no autograd graph and contributes no gradient. It works as a metric,
        not as a training loss on its own.
    """

    def __init__(self, batch=True, smooth=0.0):
        super().__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the mean Dice coefficient between ``y_pred`` and ``y_true``."""
        if self.batch:
            # Foreground = positive label. A pixel belongs to the
            # intersection when both maps carry the same foreground label;
            # requiring > 0 keeps intersection <= min(i, j), so negative
            # (e.g. ignore) labels can no longer push the score above 1.
            fg_true = y_true > 0
            i = torch.sum(fg_true)
            j = torch.sum(y_pred > 0)
            intersection = torch.sum((y_true == y_pred) & fg_true)
        else:
            i, j, intersection = _per_sample_sums(y_pred, y_true)
        return _dice_score(intersection, i, j, self.smooth).mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return the Dice loss between ``y_pred`` and ``y_true``."""
        return self.soft_dice_loss(y_pred, y_true)