77 lines
2.8 KiB
Python
77 lines
2.8 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
from torch.autograd import Variable as V
|
||
|
||
import cv2
|
||
import numpy as np
|
||
class dice_bce_loss(nn.Module):
    """Weighted sum of a cross-entropy term and a soft Dice loss term.

    Despite the ``bce`` name, the classification term is
    ``nn.CrossEntropyLoss``. With the default weights only that term
    contributes (``bce_weight=1.0``, ``dice_weight=0.0``), which matches the
    original ``1.0 * a + 0.0 * b`` combination.

    Args:
        batch: If True, the Dice coefficient pools every element of the
            batch; otherwise it is computed per sample (summing over every
            dimension after the first) and the scores are averaged.
        smooth: Additive smoothing applied to the Dice numerator and
            denominator.
        bce_weight: Weight of the cross-entropy term.
        dice_weight: Weight of the soft Dice term. When it is 0 the Dice term
            is not evaluated at all, so a shape mismatch or a NaN inside it
            cannot leak into the returned loss.
    """

    def __init__(self, batch=True, smooth=0.0, bce_weight=1.0, dice_weight=0.0):
        super(dice_bce_loss, self).__init__()
        self.batch = batch
        self.smooth = float(smooth)  # may change
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce_loss = nn.CrossEntropyLoss()

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or 1 wherever the denominator is 0.

        A zero denominator (e.g. two empty masks with smooth=0) is treated as
        a perfect match. The denominator is replaced *before* dividing so no
        inf/NaN appears in either the forward or the backward pass.
        """
        empty = denominator == 0
        safe_den = torch.where(empty, torch.ones_like(denominator), denominator)
        return torch.where(empty, torch.ones_like(numerator), numerator / safe_den)

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the mean soft Dice coefficient between ``y_pred`` and ``y_true``.

        The two tensors must be broadcast-compatible for the elementwise
        product. In batch mode all elements are pooled into one score.
        Otherwise each sample is reduced over every dimension after the first
        and the per-sample scores are averaged.
        """
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # flatten(1) generalizes the former .sum(1).sum(1).sum(1), which
            # only worked for 4-D tensors; for 4-D inputs the result is the same.
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)
        score = self._safe_ratio(2.0 * intersection + self.smooth,
                                 i + j + self.smooth)
        # IoU alternative: (intersection + smooth) / (i + j - intersection + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * CE(y_pred, y_true) + dice_weight * DiceLoss``.

        ``y_pred``/``y_true`` must satisfy ``nn.CrossEntropyLoss``'s
        input/target contract. If ``dice_weight`` is non-zero they must also
        be broadcast-compatible for the Dice product. This is implemented as
        ``forward`` rather than an overridden ``__call__`` so nn.Module hooks
        keep working; ``criterion(y_pred, y_true)`` is unchanged for callers.
        """
        loss = self.bce_weight * self.bce_loss(y_pred, y_true)
        if self.dice_weight:
            # Evaluate the Dice term only when it contributes: 0.0 * NaN is NaN.
            loss = loss + self.dice_weight * self.soft_dice_loss(y_pred, y_true)
        return loss
|
||
class dice_loss(nn.Module):
    """Dice loss (``1 - Dice``) for index-encoded label maps.

    ``y_pred`` and ``y_true`` both hold class indices. Only positive labels
    count as foreground; label 0 is excluded. A position contributes to the
    intersection when both maps carry the same positive label. The size of
    each map is its number of positive entries.

    The computation uses only comparisons and counting, so the result
    carries no gradient. It behaves as a metric rather than a trainable loss.

    Args:
        batch: If True, counts are pooled over the whole batch. Otherwise
            they are taken per sample (over every dimension after the first)
            and the per-sample scores are averaged.
        smooth: Additive smoothing applied to the Dice numerator and
            denominator.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = float(smooth)  # may change

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or 1 wherever the denominator is 0.

        A zero denominator (neither map has foreground and smooth=0) is
        treated as a perfect match instead of producing 0/0 = NaN.
        """
        empty = denominator == 0
        safe_den = torch.where(empty, torch.ones_like(denominator), denominator)
        return torch.where(empty, torch.ones_like(numerator), numerator / safe_den)

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the mean Dice coefficient over positive (foreground) labels.

        The intersection equals the former inclusion-exclusion formulation
        (matching positions minus positions where both maps are 0), written
        directly. Only positive labels enter the intersection, so the score
        stays within [0, 1].
        """
        fg_true = y_true > 0
        fg_pred = y_pred > 0
        matched = (y_true == y_pred) & fg_true
        if self.batch:
            i = fg_true.sum()
            j = fg_pred.sum()
            intersection = matched.sum()
        else:
            # Per-sample counts over every non-batch dimension. The former
            # branch summed the intersection over one dimension fewer than
            # i/j, so the shapes did not match.
            i = fg_true.flatten(1).sum(1)
            j = fg_pred.flatten(1).sum(1)
            intersection = matched.flatten(1).sum(1)
        # Integer counts + float smooth promote to a floating-point tensor.
        score = self._safe_ratio(2 * intersection + self.smooth,
                                 i + j + self.smooth)
        # IoU alternative: (intersection + smooth) / (i + j - intersection + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return the Dice loss.

        This is implemented as ``forward`` rather than an overridden
        ``__call__`` so nn.Module hooks keep working; calling
        ``criterion(y_pred, y_true)`` is unchanged.
        """
        return self.soft_dice_loss(y_pred, y_true)