落水人员检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

152 lines
5.7KB

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from torch.nn import Module
  4. import torch
  5. import torch.nn as nn
  6. class wj2_bce_Loss(Module):
  7. def __init__(self, kernel_size=11, sigma=2,weights=[0.2,1.0,1.0], as_loss=True):
  8. super().__init__()
  9. self.kernel_size = kernel_size
  10. self.sigma = sigma
  11. self.weights = torch.tensor(weights)
  12. self.as_loss = as_loss
  13. self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
  14. self.sobel_kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]])
  15. self.pad = int((self.kernel_size - 1)/2)
  16. self.tmax = torch.tensor(1.0)
  17. #self.criterion = nn.CrossEntropyLoss( reduction = 'none')
  18. self.criterion1 = nn.CrossEntropyLoss( )
  19. self.criterion2 = nn.MSELoss()
  20. def forward(self, x, y):
  21. if not self.gaussian_kernel.is_cuda:
  22. self.gaussian_kernel = self.gaussian_kernel.to(x.device)
  23. if not self.sobel_kernel.is_cuda:
  24. self.sobel_kernel = self.sobel_kernel.to(x.device)
  25. if not self.tmax.is_cuda:
  26. self.tmax = self.tmax.to(x.device)
  27. if not self.weights.is_cuda:
  28. self.weights = self.weights.to(x.device)
  29. #get pred weight
  30. preds_x = torch.argmax(x,axis=1)
  31. preds_edge = self._get_weight(preds_x)##(bs,1,h,w)
  32. #get label weight
  33. labels_edge = self._get_weight(y) ##(bs,1,h,w)
  34. celoss = self.criterion1(x,y.long())
  35. edgeloss = self.criterion2(preds_edge, labels_edge)
  36. return self.weights[0]*edgeloss + self.weights[1]*celoss
  37. def _get_weight(self,mask):
  38. ##preds 变成0,1图,(bs,h,w)
  39. mask_map = (mask <= self.tmax).float() * mask + (mask > self.tmax).float() * self.tmax
  40. mask_map = mask_map.unsqueeze(1)
  41. padLayer = nn.ReflectionPad2d(1)
  42. mask_pad = padLayer(mask_map)
  43. # 定义sobel算子参数
  44. mask_edge = torch.conv2d(mask_pad.float(), self.sobel_kernel.float(), padding=0)
  45. mask_edge = torch.absolute(mask_edge)
  46. ##低通滤波膨胀边界
  47. smooth_edge = torch.conv2d(mask_edge.float(), self.gaussian_kernel.float(), padding=self.pad)
  48. return smooth_edge
  49. def _create_gaussian_kernel(self, kernel_size, sigma):
  50. start = (1 - kernel_size) / 2
  51. end = (1 + kernel_size) / 2
  52. kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
  53. kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
  54. kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
  55. kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
  56. kernel_2d = kernel_2d.expand(1, 1, kernel_size, kernel_size).contiguous()
  57. return kernel_2d
  58. def GaussLowPassFiltering(ksize,sigma):
  59. kernel = np.zeros((ksize,ksize),dtype=np.float32)
  60. cons = 1.0/(2.0*np.pi*sigma*sigma)
  61. for i in range(ksize):
  62. for j in range(ksize):
  63. x = i - (ksize-1)/2
  64. y = j - (ksize-1)/2
  65. kernel[j,i] = cons * np.exp((-1.0)*(x**2+y**2)/2.0/(sigma**2) )
  66. print(kernel)
  67. plt.figure(0);plt.imshow(kernel);plt.show()
  68. return kernel.reshape(1,1,ksize,ksize)
  69. def create_gaussian_kernel( kernel_size, sigma):
  70. start = (1 - kernel_size) / 2
  71. end = (1 + kernel_size) / 2
  72. kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
  73. kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
  74. kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
  75. kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
  76. kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
  77. return kernel_2d
  78. def main():
  79. import matplotlib.pyplot as plt
  80. import numpy as np
  81. import torch
  82. import torch.nn as nn
  83. #preds=torch.rand(8,5,10,10)
  84. #preds=torch.argmax(preds,axis=1)
  85. preds=torch.zeros(8,100,100)
  86. preds[:,:,50:]=3.0
  87. t_max = torch.tensor(1.0)
  88. ##preds 变成0,1图,(bs,h,w)
  89. preds_map = (preds <= t_max).float() * preds + (preds > t_max).float() * t_max
  90. preds_map = preds_map.unsqueeze(1)
  91. padLayer = nn.ReflectionPad2d(1)
  92. preds_pad = padLayer(preds_map)
  93. # 定义sobel算子参数
  94. sobel_kernel =torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]])
  95. preds_edge_dilate = torch.conv2d(preds_pad.float(), sobel_kernel.float(), padding=0)
  96. preds_edge_dilate = torch.absolute(preds_edge_dilate)
  97. ##低通滤波,平滑边界
  98. f_shift ,pad, sigma= 11, 5 , 2
  99. kernel = torch.from_numpy(GaussLowPassFiltering(f_shift,sigma))
  100. smooth_edge = torch.conv2d(preds_edge_dilate.float(), kernel.float(), padding=pad)
  101. print()
  102. show_result0 = preds_map.numpy()
  103. show_result2=smooth_edge.numpy()
  104. show_result3=preds.numpy()
  105. #print(show_result2[0,0,:,5])
  106. #print(show_result2[0,0,5,:])
  107. #plt.figure(0);plt.imshow(show_result0[0,0]);plt.figure(1);
  108. plt.imshow(show_result2[0,0]);plt.show();
  109. #plt.figure(3);plt.imshow(show_result3[0]);plt.show();
  110. print()
  111. def test_loss_moule():
  112. preds=torch.rand(8,5,100,100)
  113. #preds=torch.argmax(preds,axis=1)
  114. targets =torch.zeros(8,100,100)
  115. targets[:,:,50:]=3.0
  116. for weights in [[1.0,1.0,1.0],[ 0.0,0.0,1.0],[ 1.0,0.0,0.0],[ 0.0,1.0,0.0],[ 1.0,1.0,0.0] ]:
  117. loss_layer = wj_bce_Loss(kernel_size=11, sigma=2,weights=weights, as_loss=True)
  118. loss = loss_layer(preds,targets)
  119. print(weights,' loss: ',loss)
  120. if __name__=='__main__':
  121. #main()
  122. #kk = create_gaussian_kernel( kernel_size=11, sigma=2)
  123. #print(kk.numpy().shape)
  124. #plt.figure(0);plt.imshow(kk[0,0]);plt.show()
  125. test_loss_moule()