77 lines
2.7 KiB
Python
77 lines
2.7 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import Module
|
|
import torch.nn as nn
|
|
|
|
|
|
class SSIMLoss(Module):
    """Structural-similarity (SSIM) loss computed with a Gaussian window.

    The prediction ``x`` is filtered with a fixed kernel of shape
    ``(3, 1, kernel_size, kernel_size)`` using ``groups=1``. This means ``x``
    must have exactly one channel, ``(N, 1, H, W)``, and each of the three
    output channels holds the same filtered image. The target ``y`` may be
    given as ``(N, H, W)`` or ``(N, 1, H, W)``.

    Args:
        kernel_size: Side length of the Gaussian window. Odd values keep the
            SSIM map the same spatial size as the input. An even value makes
            it one pixel larger in each spatial dimension, because the padding
            is ``kernel_size // 2`` on every side.
        sigma: Standard deviation of the Gaussian window, in pixels.
        as_loss: If True, ``forward`` returns the scalar ``1 - mean(SSIM)``.
            Otherwise it returns the per-pixel map ``1 - SSIM``, of shape
            ``(N, 3, H, W)`` for odd ``kernel_size``.
        data_range: Dynamic range of the pixel values (1.0 for images in
            ``[0, 1]``, 255 for 8-bit images). It scales the stabilising
            constants ``C1 = (0.01 * data_range) ** 2`` and
            ``C2 = (0.03 * data_range) ** 2``.
    """

    def __init__(self, kernel_size=11, sigma=1.5, as_loss=True, data_range=1.0):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.as_loss = as_loss
        self.data_range = data_range
        # A buffer follows the module through .to()/.cuda(). persistent=False
        # keeps it out of state_dict, so existing checkpoints still load
        # with strict=True.
        self.register_buffer(
            "gaussian_kernel",
            self._create_gaussian_kernel(kernel_size, sigma),
            persistent=False,
        )

    def forward(self, x, y):
        """Compute the SSIM loss between prediction ``x`` and target ``y``.

        Returns:
            ``1 - mean(SSIM)`` as a 0-d tensor when ``as_loss`` is True,
            otherwise the per-pixel ``1 - SSIM`` map.
        """
        ssim_map = self._ssim(x, y)
        if self.as_loss:
            return 1 - ssim_map.mean()
        return 1 - ssim_map

    def _ssim(self, x, y):
        """Return the per-pixel SSIM map between ``x`` and ``y``."""
        # Targets without a channel axis, e.g. (N, H, W), get one so they match
        # the single-channel layout of x. (N, 1, H, W) targets are used as-is.
        if not (y.dim() == 4 and y.size(1) == 1):
            y = y.unsqueeze(1)
        # Match the target and the window to the prediction. The window was
        # previously never moved between GPUs nor cast to half/float64, which
        # crashed conv2d. .to() is a no-op when nothing needs to change.
        y = y.to(dtype=x.dtype)
        kernel = self.gaussian_kernel.to(device=x.device, dtype=x.dtype)
        padding = self.kernel_size // 2

        def blur(t):
            # Gaussian-weighted local average.
            return F.conv2d(t, kernel, padding=padding, groups=1)

        # Local means.
        ux = blur(x)
        uy = blur(y)
        # Local variances and covariance via E[t^2] - E[t]^2.
        vx = blur(x * x) - ux * ux
        vy = blur(y * y) - uy * uy
        vxy = blur(x * y) - ux * uy

        c1 = (0.01 * self.data_range) ** 2
        c2 = (0.03 * self.data_range) ** 2
        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
        # The tiny epsilon guards against division by zero.
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Build a normalised Gaussian window of shape (3, 1, k, k).

        The 1-D profile is sampled at unit-spaced offsets centred on zero and
        normalised to sum to 1. The 2-D window is its outer product, copied
        into three identical output channels.
        """
        start = (1 - kernel_size) / 2
        end = (1 + kernel_size) / 2
        coords = torch.arange(start, end, step=1, dtype=torch.float)
        profile = torch.exp(-torch.pow(coords / sigma, 2) / 2)
        profile = (profile / profile.sum()).unsqueeze(dim=0)  # (1, k)
        kernel_2d = torch.matmul(profile.t(), profile)  # (k, k) outer product
        return kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
|
|
|
|
|
|
class Multi_SSIM_loss(Module):
    """Sum of ``SSIMLoss`` terms evaluated at several window sizes.

    Args:
        window_sizes: Gaussian window sizes. One ``SSIMLoss`` is built per
            entry, in order. A tuple default avoids a shared mutable default.
        sigma: Gaussian standard deviation shared by every window.
        as_loss: Forwarded unchanged to every ``SSIMLoss``.
    """

    def __init__(self, window_sizes=(3, 7, 11), sigma=1.5, as_loss=True):
        super().__init__()
        # ModuleList registers each sub-loss as a child module.
        self.losses = nn.ModuleList(
            [SSIMLoss(size, sigma, as_loss) for size in window_sizes]
        )

    def forward(self, img1, img2):
        """Return the sum of the per-window SSIM losses for ``img1`` vs ``img2``.

        NOTE(review): the stacked results are reduced with ``.sum()`` over all
        elements. With ``as_loss=False`` this adds up every pixel of every
        per-window map instead of returning a combined map; confirm that is
        intended.
        """
        per_window = [loss(img1, img2) for loss in self.losses]
        return torch.stack(per_window, 0).sum()
|