152 lines
5.7 KiB
Python
152 lines
5.7 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from torch.nn import Module
|
|
import torch
|
|
import torch.nn as nn
|
|
class wj2_bce_Loss(Module):
    """Segmentation loss: cross-entropy plus an MSE term on smoothed edge maps.

    Both the predicted class map (``argmax`` of the logits over dim 1) and the
    target map are clipped to at most ``tmax`` (1.0). Each is then convolved
    with a 3x3 8-neighbour Laplacian-style kernel to extract boundaries and
    blurred with a Gaussian. The returned loss is

        weights[0] * MSE(pred_edges, label_edges) + weights[1] * CE(x, y)

    Args:
        kernel_size: side length of the Gaussian smoothing kernel. With an odd
            value the smoothed edge map keeps the input's spatial size.
        sigma: standard deviation of the Gaussian kernel.
        weights: sequence of loss weights. ``weights[0]`` scales the edge term
            and ``weights[1]`` the cross-entropy. A third entry is accepted and
            stored but not read by ``forward``.
        as_loss: stored on the instance; not read anywhere in this class.

    NOTE(review): ``torch.argmax`` is not differentiable, so the edge term
    carries no gradient back to ``x``; it only changes the reported loss
    value. If it is meant to train the model, a soft edge map (e.g. built
    from softmax probabilities) would be needed.
    """

    def __init__(self, kernel_size=11, sigma=2, weights=(0.2, 1.0, 1.0), as_loss=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        # Tuple default avoids a shared mutable default; torch.tensor accepts
        # lists and tuples alike, so callers passing lists are unaffected.
        self.weights = torch.tensor(weights)
        self.as_loss = as_loss
        self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
        # 8-neighbour Laplacian-style edge kernel (the "sobel" name is historical).
        self.sobel_kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]])
        # "Same" padding for the Gaussian blur (exact for odd kernel_size).
        self.pad = int((self.kernel_size - 1) / 2)
        # Mask values above this are clipped down to it.
        self.tmax = torch.tensor(1.0)
        self.criterion1 = nn.CrossEntropyLoss()
        self.criterion2 = nn.MSELoss()

    def forward(self, x, y):
        """Return the weighted edge-MSE + cross-entropy loss as a 0-dim tensor.

        Args:
            x: logits with class scores along dim 1 (presumably (bs, C, h, w)).
            y: target class map, presumably (bs, h, w); cast to ``long`` for
                the cross-entropy term.
        """
        self._move_constants_to(x.device)

        # Edge map of the hard prediction (non-differentiable, see class note).
        preds_x = torch.argmax(x, dim=1)
        preds_edge = self._get_weight(preds_x)  # (bs, 1, h, w)
        # Edge map of the target.
        labels_edge = self._get_weight(y)  # (bs, 1, h, w)

        celoss = self.criterion1(x, y.long())
        edgeloss = self.criterion2(preds_edge, labels_edge)
        return self.weights[0] * edgeloss + self.weights[1] * celoss

    def _move_constants_to(self, device):
        """Move the constant tensors onto ``device`` (no-op if already there)."""
        # An `is_cuda` check would never move a tensor back to the CPU nor
        # between GPUs; an unconditional `.to` handles every case.
        self.gaussian_kernel = self.gaussian_kernel.to(device)
        self.sobel_kernel = self.sobel_kernel.to(device)
        self.tmax = self.tmax.to(device)
        self.weights = self.weights.to(device)

    def _get_weight(self, mask):
        """Return the Gaussian-smoothed absolute edge response of ``mask``.

        ``mask`` is a (bs, h, w) class map (float or integer). Values above
        ``self.tmax`` are clipped to it. The map is reflect-padded by 1 and
        convolved with the 3x3 edge kernel, and the absolute response is
        blurred with the Gaussian kernel. Output shape is (bs, 1, h', w'),
        with h' == h and w' == w for odd ``kernel_size``.
        """
        # Clip to tmax: with non-negative integer class ids this is a 0/1 map.
        mask_map = (mask <= self.tmax).float() * mask + (mask > self.tmax).float() * self.tmax
        mask_map = mask_map.unsqueeze(1)
        mask_pad = nn.ReflectionPad2d(1)(mask_map)

        # 3x3 edge kernel; padding=0 because the input is already padded.
        mask_edge = torch.conv2d(mask_pad.float(), self.sobel_kernel.float(), padding=0)
        mask_edge = torch.absolute(mask_edge)

        # Low-pass filtering spreads (dilates) the thin edge response.
        return torch.conv2d(mask_edge.float(), self.gaussian_kernel.float(), padding=self.pad)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Return a normalised 2-D Gaussian kernel of shape (1, 1, k, k) summing to 1."""
        # Exactly kernel_size offsets centred on 0 (e.g. -5..5 for k=11).
        offsets = torch.arange(kernel_size, dtype=torch.float) - (kernel_size - 1) / 2
        kernel_1d = torch.exp(-torch.pow(offsets / sigma, 2) / 2)
        kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)

        # Separable Gaussian: outer product of the 1-D kernel with itself.
        kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
        return kernel_2d.expand(1, 1, kernel_size, kernel_size).contiguous()
|
|
|
|
|
|
def GaussLowPassFiltering(ksize, sigma, *, show=True):
    """Build a ``ksize`` x ``ksize`` Gaussian kernel as a float32 NumPy array.

    Each entry is ``exp(-(x**2 + y**2) / (2 * sigma**2)) / (2 * pi * sigma**2)``,
    where ``x`` and ``y`` are the offsets from the kernel centre. The kernel is
    not renormalised to sum to 1 after truncation.

    Args:
        ksize: kernel side length.
        sigma: Gaussian standard deviation.
        show: if True (the default, matching the original behaviour), print
            the kernel and display it with matplotlib; ``plt.show()`` may
            block on interactive backends. Pass False for silent use.

    Returns:
        Array of shape (1, 1, ksize, ksize), usable as a ``torch.conv2d``
        weight after ``torch.from_numpy``.
    """
    cons = 1.0 / (2.0 * np.pi * sigma * sigma)
    # Offsets from the centre along one axis; the 2-D squared distance is
    # their outer sum (the kernel is symmetric, so axis order is irrelevant).
    offsets = np.arange(ksize) - (ksize - 1) / 2
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = (cons * np.exp(-sq_dist / 2.0 / (sigma ** 2))).astype(np.float32)

    if show:
        print(kernel)
        plt.figure(0)
        plt.imshow(kernel)
        plt.show()

    return kernel.reshape(1, 1, ksize, ksize)
|
|
|
|
def create_gaussian_kernel(kernel_size, sigma, channels=3):
    """Return a normalised 2-D Gaussian kernel replicated over ``channels``.

    The result has shape (channels, 1, kernel_size, kernel_size) and holds
    one identical kernel per channel, i.e. the weight layout for a depthwise
    ``torch.conv2d(..., groups=channels)``. Each 2-D kernel sums to 1.

    Args:
        kernel_size: side length of the kernel.
        sigma: Gaussian standard deviation.
        channels: number of output channels (default 3, the original fixed value).
    """
    # Exactly kernel_size offsets centred on 0, built from integers so the
    # length never depends on floating-point arange bounds.
    offsets = torch.arange(kernel_size, dtype=torch.float) - (kernel_size - 1) / 2
    kernel_1d = torch.exp(-torch.pow(offsets / sigma, 2) / 2)
    kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)

    # Separable Gaussian: outer product of the 1-D kernel with itself.
    kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
    return kernel_2d.expand(channels, 1, kernel_size, kernel_size).contiguous()
|
|
|
|
def main():
    """Plot the smoothed edge map of a synthetic two-region label image.

    A visual check of the clip -> reflect-pad -> 3x3 edge kernel -> Gaussian
    blur pipeline that ``wj2_bce_Loss`` applies to its masks. Here the blur
    uses the kernel returned by ``GaussLowPassFiltering``, which also prints
    and displays that kernel.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torch.nn as nn

    # Synthetic labels (bs, h, w): left half class 0, right half class 3.
    labels = torch.zeros(8, 100, 100)
    labels[:, :, 50:] = 3.0
    clip = torch.tensor(1.0)

    # Clip values above `clip` to get a 0/1 map, then add a channel axis.
    keep = (labels <= clip).float()
    over = (labels > clip).float()
    binary_map = (keep * labels + over * clip).unsqueeze(1)

    # 3x3 8-neighbour Laplacian-style edge kernel on the reflect-padded map.
    edge_kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]]).float()
    padded = nn.ReflectionPad2d(1)(binary_map)
    edges = torch.conv2d(padded.float(), edge_kernel, padding=0).absolute()

    # Low-pass filter to smooth/widen the boundary (11x11 kernel, "same" padding).
    ksize, half_width, sigma = 11, 5, 2
    blur = torch.from_numpy(GaussLowPassFiltering(ksize, sigma)).float()
    smoothed = torch.conv2d(edges.float(), blur, padding=half_width)
    print()

    plt.imshow(smoothed.numpy()[0, 0])
    plt.show()
    print()
|
|
def test_loss_moule():
    """Smoke-test ``wj2_bce_Loss`` with several weight combinations.

    Prints the loss for each weight triple. The logits are random
    (bs=8, 5 classes, 100x100), so the printed values differ between runs.
    Targets are a fixed two-region map (left half 0, right half 3).
    """
    preds = torch.rand(8, 5, 100, 100)

    targets = torch.zeros(8, 100, 100)
    targets[:, :, 50:] = 3.0

    weight_sets = [
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
    ]
    for weights in weight_sets:
        # The loss class defined in this file is `wj2_bce_Loss`; the previous
        # reference to `wj_bce_Loss` raised NameError on every call.
        loss_layer = wj2_bce_Loss(kernel_size=11, sigma=2, weights=weights, as_loss=True)
        loss = loss_layer(preds, targets)
        print(weights, ' loss: ', loss)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the loss smoke test. For visual checks, main()
    # plots a smoothed edge map and create_gaussian_kernel() builds the
    # 3-channel blur kernel.
    test_loss_moule()
|