落水人员检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.8KB

  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable as V
  4. import cv2
  5. import numpy as np
  6. class dice_bce_loss(nn.Module):
  7. def __init__(self, batch=True):
  8. super(dice_bce_loss, self).__init__()
  9. self.batch = batch
  10. self.bce_loss = nn.CrossEntropyLoss()
  11. def soft_dice_coeff(self, y_pred,y_true):
  12. smooth = 0.0 # may change
  13. if self.batch:
  14. i = torch.sum(y_true)
  15. j = torch.sum(y_pred)
  16. intersection = torch.sum(y_true * y_pred)
  17. else:
  18. i = y_true.sum(1).sum(1).sum(1)
  19. j = y_pred.sum(1).sum(1).sum(1)
  20. intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
  21. score = (2. * intersection + smooth) / (i + j + smooth)
  22. #score = (intersection + smooth) / (i + j - intersection + smooth)#iou
  23. return score.mean()
  24. def soft_dice_loss(self, y_pred,y_true):
  25. loss = 1 - self.soft_dice_coeff(y_true, y_pred)
  26. return loss
  27. def __call__(self, y_pred,y_true):
  28. #print(y_true.requires_grad,y_pred.requires_grad);
  29. a = self.bce_loss(y_pred, y_true)
  30. b = self.soft_dice_loss(y_true, y_pred)
  31. return 1.0* a + 0.0 * b
  32. class dice_loss(nn.Module):
  33. def __init__(self, batch=True):
  34. super(dice_loss, self).__init__()
  35. self.batch = batch
  36. def soft_dice_coeff(self, y_pred,y_true):
  37. smooth = 0.0 # may change
  38. if self.batch:
  39. '''i = torch.sum(y_true)
  40. j = torch.sum(y_pred)
  41. intersection2= y_true * y_pred
  42. intersection = torch.sum(y_true * y_pred)'''
  43. ##y_true,y_pred都是index编码。
  44. ##step1,求取类别0的交集个数
  45. true_zeros = torch.sum(y_true==0)
  46. pred_zeros = torch.sum(y_pred==0)
  47. all = torch.sum((y_true*y_pred)==0)
  48. zeros_cross = true_zeros + pred_zeros - all
  49. ##step2,去取交集的数目
  50. cross = torch.sum(y_pred == y_true)
  51. y_true_p = torch.sum(y_true>0 )
  52. y_pred_p = torch.sum(y_pred>0)
  53. i = y_true_p
  54. j = y_pred_p
  55. intersection = cross - zeros_cross
  56. else:
  57. i = y_true.sum(1).sum(1).sum(1)
  58. j = y_pred.sum(1).sum(1).sum(1)
  59. intersection = (y_true * y_pred).sum(1).sum(1)
  60. score = ( 2 * intersection + smooth) / (i + j + smooth)
  61. #score = (intersection + smooth) / (i + j - intersection + smooth)#iou
  62. return score.mean()
  63. def soft_dice_loss(self, y_pred,y_true):
  64. loss = 1 - self.soft_dice_coeff(y_true, y_pred)
  65. return loss
  66. def __call__(self, y_pred,y_true):
  67. b = self.soft_dice_loss(y_true, y_pred)
  68. return b