Drowning_Person_Detection/utils/SSIM_loss.py

77 lines
2.7 KiB
Python

import torch
import torch.nn.functional as F
from torch.nn import Module
import torch.nn as nn
class SSIMLoss(Module):
    """Structural-similarity (SSIM) loss using a Gaussian weighting window.

    The Gaussian kernel has shape ``(3, 1, k, k)`` and is applied with
    ``groups=1``, so the prediction ``x`` must have exactly one channel and
    every filtered map comes out with three identical channels.

    Args:
        kernel_size: Side length ``k`` of the square Gaussian window.
        sigma: Standard deviation of the Gaussian window.
        as_loss: If true, ``forward`` returns the scalar ``1 - mean(SSIM)``;
            otherwise it returns the per-pixel map ``1 - SSIM``.
    """

    def __init__(self, kernel_size=11, sigma=1.5, as_loss=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.as_loss = as_loss
        # Plain tensor attribute (not a registered buffer); ``forward`` is
        # responsible for moving it to the device of the incoming data.
        self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)

    def forward(self, x, y):
        """Return ``1 - SSIM`` between prediction ``x`` and target ``y``.

        Args:
            x: Prediction tensor with a single channel, ``(B, 1, H, W)``.
            y: Target tensor, either ``(B, H, W)`` (a channel axis is
                inserted) or already ``(B, 1, H, W)``.

        Returns:
            A scalar tensor when ``as_loss`` is true, otherwise a tensor of
            per-pixel values with three (identical) channels.
        """
        # Compare devices directly: checking only ``is_cuda`` leaves the
        # kernel on the wrong device when the input lives on another GPU or
        # when the module is used on CPU after a CUDA call.
        if self.gaussian_kernel.device != x.device:
            self.gaussian_kernel = self.gaussian_kernel.to(x.device)

        ssim_map = self._ssim(x, y)
        if self.as_loss:
            return 1 - ssim_map.mean()
        return 1 - ssim_map

    def _ssim(self, x, y):
        """Compute the per-pixel SSIM map between ``x`` and ``y``."""
        # Insert the channel axis for channel-less targets. A target that is
        # already 4-D is used as-is (unsqueezing it would produce a 5-D
        # tensor that conv2d rejects).
        if y.dim() != 4:
            y = y.unsqueeze(1)

        kernel = self.gaussian_kernel
        pad = self.kernel_size // 2

        def smooth(t):
            # Gaussian-weighted local average of ``t``.
            return F.conv2d(t, kernel, padding=pad, groups=1)

        # Local means.
        ux = smooth(x)
        uy = smooth(y)

        # Local second moments, turned into (co)variances below.
        uxx = smooth(x * x)
        uyy = smooth(y * y)
        uxy = smooth(x * y)

        vx = uxx - ux * ux
        vy = uyy - uy * uy
        vxy = uxy - ux * uy

        # Stabilising constants of the SSIM formula.
        # NOTE(review): 0.01**2 / 0.03**2 presumably assume inputs scaled to
        # a dynamic range of 1 -- confirm against the callers.
        c1 = 0.01**2
        c2 = 0.03**2

        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux**2 + uy**2 + c1) * (vx + vy + c2)
        # The extra epsilon guards against a zero denominator.
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Build a normalised 2-D Gaussian kernel of shape ``(3, 1, k, k)``."""
        # Sample positions centred on zero, one per kernel tap.
        start = (1 - kernel_size) / 2
        end = (1 + kernel_size) / 2
        kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
        kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
        # Normalise so the taps sum to one, then keep it as a row vector.
        kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)

        # Outer product of the 1-D kernel with itself gives the 2-D kernel.
        kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
        # Replicate into three output channels with one input channel each.
        kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
        return kernel_2d
class Multi_SSIM_loss(Module):
    """Sum of SSIM losses evaluated with several Gaussian window sizes.

    Args:
        window_sizes: Iterable of kernel sizes; one ``SSIMLoss`` is created
            per entry.
        sigma: Gaussian standard deviation shared by all windows.
        as_loss: Forwarded to every ``SSIMLoss``.
    """

    def __init__(self, window_sizes=(3, 7, 11), sigma=1.5, as_loss=True):
        super(Multi_SSIM_loss, self).__init__()
        # A ModuleList registers the sub-losses as child modules.
        self.losses = nn.ModuleList(
            [SSIMLoss(size, sigma, as_loss) for size in window_sizes]
        )

    def forward(self, img1, img2):
        """Return the sum of the per-window SSIM losses for the two inputs.

        ``.sum()`` reduces over every element of the stacked results, so the
        return value is a scalar tensor.
        """
        per_window = [loss_fn(img1, img2) for loss_fn in self.losses]
        return torch.stack(per_window, 0).sum()