Drowning_Person_Detection/utils/bce_loss.py

77 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from torch.autograd import Variable as V
import cv2
import numpy as np
class dice_bce_loss(nn.Module):
    """Weighted sum of a cross-entropy term and a soft Dice term.

    With the default weights (``bce_weight=1.0``, ``dice_weight=0.0``) the
    result is exactly ``nn.CrossEntropyLoss()(y_pred, y_true)``.

    A term whose weight is zero is not evaluated at all.  This matters
    because the two terms accept different input layouts: cross entropy takes
    class-index targets, while the Dice term multiplies ``y_true * y_pred``
    element-wise and therefore needs broadcast-compatible tensors.  Always
    evaluating the Dice term (and then multiplying it by zero) would raise on
    index targets and would turn the total into NaN whenever the Dice ratio
    is 0/0.

    Args:
        batch: If true, the Dice coefficient is computed over the whole batch
            at once; otherwise it is computed per sample and averaged.
        smooth: Constant added to the Dice numerator and denominator.  The
            default of 0.0 yields NaN when both inputs sum to zero; pass a
            small positive value to avoid that.
        bce_weight: Multiplier of the cross-entropy term.
        dice_weight: Multiplier of the soft Dice term.

    Raises:
        ValueError: If both weights are zero (there would be nothing to
            compute).
    """

    def __init__(self, batch=True, smooth=0.0, bce_weight=1.0, dice_weight=0.0):
        super(dice_bce_loss, self).__init__()
        if bce_weight == 0 and dice_weight == 0:
            raise ValueError("bce_weight and dice_weight cannot both be zero")
        self.batch = batch
        self.smooth = smooth
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # NOTE: despite the attribute name this is multi-class cross entropy,
        # not binary cross entropy.  The name is kept for compatibility.
        self.bce_loss = nn.CrossEntropyLoss()

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the soft Dice coefficient ``2*|A.B| / (|A| + |B|)``.

        The per-sample branch reduces dimensions 1, 2 and 3, so it expects
        4-D tensors (presumably ``(N, C, H, W)`` -- confirm with the caller).
        """
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            i = y_true.sum(1).sum(1).sum(1)
            j = y_pred.sum(1).sum(1).sum(1)
            intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
        score = (2.0 * intersection + self.smooth) / (i + j + self.smooth)
        # IoU alternative:
        # (intersection + smooth) / (i + j - intersection + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * CE + dice_weight * Dice``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that ``nn.Module`` hooks keep working; ``loss(y_pred, y_true)`` is
        called exactly as before.
        """
        total = None
        if self.bce_weight != 0:
            total = self.bce_weight * self.bce_loss(y_pred, y_true)
        if self.dice_weight != 0:
            dice_term = self.dice_weight * self.soft_dice_loss(y_pred, y_true)
            total = dice_term if total is None else total + dice_term
        return total
class dice_loss(nn.Module):
    """Dice loss, ``1 - Dice``, with two ways of computing the coefficient.

    * ``batch=True``: both inputs are treated as index-encoded label maps
      (class 0 is background).  The coefficient is
      ``2 * #(pred == true and label != 0) / (#(true > 0) + #(pred > 0))``.
      This branch is built from comparisons and counts, so it carries no
      gradient; it behaves as a metric rather than a trainable loss.
    * ``batch=False``: soft Dice per sample from element-wise products,
      reduced over dimensions 1, 2 and 3 (so 4-D tensors are expected --
      presumably ``(N, C, H, W)``; confirm with the caller), then averaged.

    Args:
        batch: Selects the branch described above.
        smooth: Constant added to numerator and denominator.  The default of
            0.0 yields NaN when the denominator is zero (no foreground in
            either input); pass a small positive value to avoid that.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the Dice coefficient of ``y_pred`` against ``y_true``."""
        if self.batch:
            # y_true and y_pred are both index-encoded.
            # Count the positions where the labels agree, excluding those
            # where both are background (class 0).  Equal labels with a
            # non-zero y_true are exactly the "agree and not both zero" set.
            intersection = torch.sum((y_pred == y_true) & (y_true != 0))
            i = torch.sum(y_true > 0)
            j = torch.sum(y_pred > 0)
        else:
            i = y_true.sum(1).sum(1).sum(1)
            j = y_pred.sum(1).sum(1).sum(1)
            # Reduce the same three dimensions as i and j so that all three
            # are per-sample vectors of identical shape.
            intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
        score = (2 * intersection + self.smooth) / (i + j + self.smooth)
        # IoU alternative:
        # (intersection + smooth) / (i + j - intersection + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return the Dice loss.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that ``nn.Module`` hooks keep working; ``loss(y_pred, y_true)`` is
        called exactly as before.
        """
        return self.soft_dice_loss(y_pred, y_true)