落水人员检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 line
2.7KB

  1. import torch
  2. import torch.nn.functional as F
  3. from torch.nn import Module
  4. import torch.nn as nn
  5. class SSIMLoss(Module):
  6. def __init__(self, kernel_size=11, sigma=1.5, as_loss=True):
  7. super().__init__()
  8. self.kernel_size = kernel_size
  9. self.sigma = sigma
  10. self.as_loss = as_loss
  11. self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
  12. def forward(self, x, y):
  13. if not self.gaussian_kernel.is_cuda:
  14. self.gaussian_kernel = self.gaussian_kernel.to(x.device)
  15. ssim_map = self._ssim(x, y)
  16. if self.as_loss:
  17. return 1 - ssim_map.mean()
  18. else:
  19. return 1 - ssim_map
  20. def _ssim(self, x, y):
  21. # Compute means
  22. y = y.unsqueeze(1)
  23. #print('line30:',x.shape,self.gaussian_kernel.shape,' Y.shape:',y.shape)
  24. ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=1)
  25. uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=1)
  26. # Compute variances
  27. uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=1)
  28. uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=1)
  29. uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=1)
  30. vx = uxx - ux * ux
  31. vy = uyy - uy * uy
  32. vxy = uxy - ux * uy
  33. c1 = 0.01**2
  34. c2 = 0.03**2
  35. numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
  36. denominator = (ux**2 + uy**2 + c1) * (vx + vy + c2)
  37. return numerator / (denominator + 1e-12)
  38. def _create_gaussian_kernel(self, kernel_size, sigma):
  39. start = (1 - kernel_size) / 2
  40. end = (1 + kernel_size) / 2
  41. kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
  42. kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
  43. kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
  44. kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
  45. kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
  46. return kernel_2d
  47. class Multi_SSIM_loss(Module):
  48. def __init__(self, window_sizes=[3, 7, 11], sigma=1.5, as_loss=True):
  49. super(Multi_SSIM_loss, self).__init__()
  50. self.losses = list()
  51. for size in window_sizes:
  52. self.losses.append(SSIMLoss(size, sigma, as_loss))
  53. self.losses = nn.ModuleList(self.losses)
  54. def forward(self, img1, img2):
  55. loss_multi = list()
  56. for i in range(len(self.losses)):
  57. loss_multi.append(self.losses[i](img1, img2))
  58. total_loss = torch.stack(loss_multi, 0).sum()
  59. return total_loss