import torch
import torch.nn.functional as F
from torch.nn import Module
import torch.nn as nn


class SSIMLoss(Module):
    """Single-scale SSIM loss computed with a Gaussian window.

    The window is built as a ``(3, 1, kernel_size, kernel_size)`` kernel and
    applied with ``groups=1``. The prediction ``x`` must therefore have exactly
    one channel, and every local statistic comes out as three identical
    channels.

    NOTE(review): ``groups=1`` with a 3-way expanded kernel looks like it may
    have been intended as a per-channel (``groups=3``) filter for RGB input.
    It is kept as-is so the shape of the returned map does not change.

    Args:
        kernel_size: Side length of the Gaussian window. The window is applied
            with ``padding=kernel_size // 2``, so odd sizes preserve H and W.
        sigma: Standard deviation of the Gaussian window.
        as_loss: If True, ``forward`` returns the scalar ``1 - mean(SSIM)``.
            Otherwise it returns the per-pixel map ``1 - SSIM``.
        data_range: Dynamic range ``L`` of the inputs (e.g. 1.0 for values in
            [0, 1], 255.0 for 8-bit values). It sets the stabilizing constants
            ``C1 = (0.01 * L) ** 2`` and ``C2 = (0.03 * L) ** 2``. The default
            reproduces the previously hard-coded constants.
    """

    def __init__(self, kernel_size=11, sigma=1.5, as_loss=True, *, data_range=1.0):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.as_loss = as_loss
        self.data_range = data_range
        # A buffer (unlike a plain tensor attribute) follows module.to()/.cuda()/
        # .half() and is broadcast to DataParallel replicas. persistent=False
        # keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            "gaussian_kernel",
            self._create_gaussian_kernel(self.kernel_size, self.sigma),
            persistent=False,
        )

    def forward(self, x, y):
        """Return the SSIM dissimilarity between prediction ``x`` and target ``y``.

        Args:
            x: Prediction fed straight to ``conv2d``. It needs exactly one
                channel, presumably shaped (B, 1, H, W); confirm against callers.
            y: Target, either without a channel axis (e.g. (B, H, W)) or
                already (B, 1, H, W).

        Returns:
            ``1 - mean(SSIM)`` as a 0-d tensor when ``as_loss`` is True,
            otherwise the per-pixel map ``1 - SSIM``.
        """
        ssim_map = self._ssim(x, y)
        if self.as_loss:
            return 1 - ssim_map.mean()
        return 1 - ssim_map

    def _ssim(self, x, y):
        """Compute the per-pixel SSIM map between ``x`` and ``y``."""
        # The target may arrive without a channel axis; add one. A 4-D target
        # is used as-is (unsqueezing it would produce an invalid 5-D conv input).
        if y.dim() != 4:
            y = y.unsqueeze(1)

        # Align the window with the input on every call. This is a no-op when
        # they already match, and it makes fp16/bf16/fp64 inputs work, which
        # previously hit a weight/input dtype mismatch in conv2d.
        kernel = self.gaussian_kernel.to(x.device)
        if x.is_floating_point():
            kernel = kernel.to(x.dtype)
            y = y.to(x.dtype)

        pad = self.kernel_size // 2

        def blur(t):
            # Gaussian-weighted local average of t.
            return F.conv2d(t, kernel, padding=pad, groups=1)

        # Local means.
        ux = blur(x)
        uy = blur(y)
        # Local variances and covariance: E[a*b] - E[a] * E[b].
        vx = blur(x * x) - ux * ux
        vy = blur(y * y) - uy * uy
        vxy = blur(x * y) - ux * uy

        c1 = (0.01 * self.data_range) ** 2
        c2 = (0.03 * self.data_range) ** 2
        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
        # The small epsilon guards against division by zero.
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Build a normalized 2-D Gaussian window of shape (3, 1, k, k).

        The 1-D Gaussian is sampled at ``kernel_size`` points centred on zero
        and normalized to sum to 1. The 2-D window is its outer product with
        itself, so it also sums to 1.
        """
        start = (1 - kernel_size) / 2
        end = (1 + kernel_size) / 2
        kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
        kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
        kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
        kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
        # Three copies of the same window (see the class docstring note).
        kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
        return kernel_2d


class Multi_SSIM_loss(Module):
    """Sum of ``SSIMLoss`` values computed with several Gaussian window sizes.

    Args:
        window_sizes: Iterable of window side lengths, one ``SSIMLoss`` each.
        sigma: Gaussian standard deviation shared by every window.
        as_loss: Forwarded to each ``SSIMLoss``.
        data_range: Dynamic range of the inputs, forwarded to each ``SSIMLoss``.

    NOTE(review): with ``as_loss=False`` the per-pixel maps are stacked and
    then summed into a single scalar, which may not be what callers expect.
    This is kept unchanged for compatibility.
    """

    def __init__(self, window_sizes=(3, 7, 11), sigma=1.5, as_loss=True, *, data_range=1.0):
        super().__init__()
        # A tuple default avoids the shared-mutable-default pitfall.
        self.losses = nn.ModuleList(
            SSIMLoss(size, sigma, as_loss, data_range=data_range)
            for size in window_sizes
        )

    def forward(self, img1, img2):
        """Return the sum of every per-window SSIM loss between ``img1`` and ``img2``."""
        per_window = [loss(img1, img2) for loss in self.losses]
        return torch.stack(per_window, 0).sum()