|
- import torch
- import torch.nn as nn
- from torch.autograd import Variable as V
-
- import cv2
- import numpy as np
class dice_bce_loss(nn.Module):
    """Weighted sum of a cross-entropy loss and a soft Dice loss.

    With the default weights (``bce_weight=1.0``, ``dice_weight=0.0``) the
    result is exactly the cross-entropy term.  In that case the Dice term is
    not computed at all, so it can neither raise (shape mismatch in the
    element-wise product) nor inject NaN/inf into the loss
    (``0.0 * nan == nan``).

    Args:
        batch: if True, Dice is computed over the whole batch at once;
            otherwise it is computed per sample (dim 0 is the batch dim,
            every other dim is summed) and averaged.
        bce_weight: weight of the cross-entropy term.
        dice_weight: weight of the Dice-loss term.
        smooth: additive smoothing in the Dice numerator and denominator.
            With 0.0, an input pair whose sums are both zero divides by zero.
    """

    def __init__(self, batch=True, *, bce_weight=1.0, dice_weight=0.0, smooth=0.0):
        super(dice_bce_loss, self).__init__()
        self.batch = batch
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth
        # Despite its name this is (multi-class) cross-entropy; the attribute
        # name is kept so external references to `self.bce_loss` still work.
        self.bce_loss = nn.CrossEntropyLoss()

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the soft Dice coefficient ``(2*|A*B| + s) / (|A| + |B| + s)``.

        ``y_pred`` and ``y_true`` are multiplied element-wise, so they must be
        broadcast-compatible.  NOTE(review): with CrossEntropyLoss-style
        (N, C, ...) logits and (N, ...) index targets they generally are not;
        callers enabling ``dice_weight`` should pass matching shapes.
        """
        smooth = self.smooth
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # Sum every non-batch dim -> one value per sample.
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)
        score = (2. * intersection + smooth) / (i + j + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * CE(y_pred, y_true) [+ dice_weight * Dice loss]``."""
        loss = self.bce_weight * self.bce_loss(y_pred, y_true)
        # Skip the Dice term entirely when it would be multiplied by zero:
        # computing it could only crash or turn the loss into NaN.
        if self.dice_weight != 0:
            loss = loss + self.dice_weight * self.soft_dice_loss(y_pred, y_true)
        return loss
class dice_loss(nn.Module):
    """Dice loss (``1 - Dice``) between a prediction and a target.

    In batch mode both inputs are treated as label-index maps in which label 0
    is background.  The counts are built from hard comparisons (``==``/``>``),
    so that path carries no gradient.  NOTE(review): it looks intended as an
    evaluation metric rather than a training objective; confirm.

    Args:
        batch: if True, use the label-map Dice over the whole batch; otherwise
            compute a soft Dice per sample (dim 0 is the batch dim, every
            other dim is summed) and average it.
        smooth: additive smoothing in the Dice numerator and denominator.
            With 0.0, an all-background pair divides by zero (NaN).
    """

    def __init__(self, batch=True, *, smooth=0.0):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the Dice coefficient ``(2*|A∩B| + s) / (|A| + |B| + s)``."""
        smooth = self.smooth
        if self.batch:
            # y_true and y_pred are both index (label) encoded; |A| and |B|
            # are the foreground (label > 0) pixel counts.
            i = torch.sum(y_true > 0)
            j = torch.sum(y_pred > 0)
            # Pixels where both maps carry the same foreground label, i.e.
            # count(equal) minus count(both background).
            intersection = torch.sum((y_true == y_pred) & (y_true > 0))
        else:
            # Soft Dice per sample: sum every non-batch dim.
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)
        score = (2 * intersection + smooth) / (i + j + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return the Dice loss of ``y_pred`` against ``y_true``."""
        return self.soft_dice_loss(y_pred, y_true)
|