|
- import torch
- import torch.nn.functional as F
- from torch.nn import Module
- import torch.nn as nn
-
-
class SSIMLoss(Module):
    """Single-scale SSIM dissimilarity using a fixed Gaussian window.

    The window is a normalised 2-D Gaussian of side ``kernel_size`` and
    standard deviation ``sigma``. Local statistics are computed with
    ``F.conv2d(..., groups=1)`` against a ``(3, 1, k, k)`` weight. The input
    must therefore have exactly one channel, and each of the 3 output
    channels is an identical copy of the same single-channel SSIM map.

    Args:
        kernel_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.
        as_loss: If True, ``forward`` returns the scalar ``1 - mean(SSIM)``;
            otherwise it returns the per-pixel map ``1 - SSIM``.
    """

    def __init__(self, kernel_size=11, sigma=1.5, as_loss=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.as_loss = as_loss
        # Registered as a buffer so module.to()/.cuda()/.double() move it
        # with the module. It is non-persistent so the state_dict keeps the
        # same keys as before (the kernel was never saved).
        self.register_buffer(
            "gaussian_kernel",
            self._create_gaussian_kernel(kernel_size, sigma),
            persistent=False,
        )

    def forward(self, x, y):
        """Return the SSIM-based dissimilarity between ``x`` and ``y``.

        Args:
            x: Prediction. Must have a single channel, e.g. ``(N, 1, H, W)``,
                because the convolution uses ``groups=1`` with one input
                channel.
            y: Target, either without the channel axis (e.g. ``(N, H, W)``;
                one is inserted) or with it (e.g. ``(N, 1, H, W)``).

        Returns:
            ``1 - mean(SSIM)`` as a scalar tensor when ``as_loss`` is True.
            Otherwise the map ``1 - SSIM``, shaped like the conv output
            (3 identical channels; spatial size preserved for odd kernels).
        """
        ssim_map = self._ssim(x, y)
        if self.as_loss:
            return 1 - ssim_map.mean()
        return 1 - ssim_map

    def _ssim(self, x, y):
        """Compute the per-pixel SSIM map between ``x`` and ``y``."""
        # Add the channel axis only when the target lacks it. Unconditionally
        # unsqueezing a target that already had one produced a 5-D tensor
        # that conv2d rejects.
        if y.dim() == x.dim() - 1:
            y = y.unsqueeze(-3)
        # Match the kernel (and the target's dtype) to the input so CPU/GPU
        # and float16/float64 inputs all work. `.to` is a no-op when they
        # already match.
        kernel = self.gaussian_kernel.to(device=x.device, dtype=x.dtype)
        y = y.to(dtype=x.dtype)
        pad = self.kernel_size // 2

        def blur(t):
            # Gaussian-weighted local average.
            return F.conv2d(t, kernel, padding=pad, groups=1)

        # Local means.
        ux = blur(x)
        uy = blur(y)
        # Local (co)variances via E[ab] - E[a]E[b].
        vx = blur(x * x) - ux * ux
        vy = blur(y * y) - uy * uy
        vxy = blur(x * y) - ux * uy

        # Stabilising constants with K1=0.01, K2=0.03 and an implied data
        # range of 1.
        c1 = 0.01 ** 2
        c2 = 0.03 ** 2
        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Build a normalised Gaussian window of shape ``(3, 1, k, k)``.

        Args:
            kernel_size: Side length ``k`` of the window.
            sigma: Standard deviation of the Gaussian.

        Returns:
            A float32 tensor whose 3 slices are identical copies of a 2-D
            Gaussian that sums to 1.
        """
        # Integer arange centred afterwards. This avoids float-endpoint
        # rounding that can make arange return one element too many or few.
        coords = torch.arange(kernel_size, dtype=torch.float) - (kernel_size - 1) / 2
        kernel_1d = torch.exp(-((coords / sigma) ** 2) / 2)
        kernel_1d = kernel_1d / kernel_1d.sum()

        # Outer product of the separable 1-D profile.
        kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
        return kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
-
-
class Multi_SSIM_loss(Module):
    """Sum of ``SSIMLoss`` terms computed at several Gaussian window sizes.

    Args:
        window_sizes: Iterable of window sizes. One ``SSIMLoss`` is built per
            entry, in order.
        sigma: Gaussian standard deviation shared by every scale.
        as_loss: Forwarded unchanged to each ``SSIMLoss``.
    """

    def __init__(self, window_sizes=(3, 7, 11), sigma=1.5, as_loss=True):
        super().__init__()
        # Tuple default avoids the shared-mutable-default pitfall. Lists or
        # any other iterable of sizes are still accepted.
        self.losses = nn.ModuleList(
            SSIMLoss(size, sigma, as_loss) for size in window_sizes
        )

    def forward(self, img1, img2):
        """Return the sum of the per-scale SSIM losses for ``img1``/``img2``.

        The per-scale results are stacked along a new leading axis and summed
        over every element, so the result is a scalar tensor. Stacking
        requires all per-scale results to share a shape. An empty
        ``window_sizes`` makes ``torch.stack`` raise.
        """
        per_scale = [loss(img1, img2) for loss in self.losses]
        return torch.stack(per_scale, 0).sum()
|