"""Edge-aware cross-entropy losses for segmentation, plus small demo helpers.

``wj_bce_Loss`` re-weights a per-pixel cross-entropy map with smoothed
boundary maps derived from the prediction and from the label.
``thFloater`` combines the cross-entropy of two prediction heads.
"""
import numpy as np

try:
    # Only needed for ad-hoc visualisation; the losses work without it.
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - optional dependency
    plt = None

from torch.nn import Module
import torch
import torch.nn as nn


# 3x3 8-neighbour Laplacian-style kernel (centre 8, neighbours -1).  The
# response is zero on constant regions and non-zero where the map changes.
# The attribute that holds it is called ``sobel_kernel`` for historical
# reasons.
_EDGE_KERNEL = [[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]]


def GaussLowPassFiltering(ksize, sigma):
    """Return an un-normalised 2-D Gaussian kernel as a NumPy array.

    Args:
        ksize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        ``float32`` array of shape ``(1, 1, ksize, ksize)`` whose entries are
        ``1 / (2*pi*sigma^2) * exp(-(x^2 + y^2) / (2*sigma^2))`` with ``x``
        and ``y`` measured from the kernel centre.
    """
    cons = 1.0 / (2.0 * np.pi * sigma * sigma)
    offsets = np.arange(ksize, dtype=np.float64) - (ksize - 1) / 2
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = cons * np.exp((-1.0) * sq_dist / 2.0 / (sigma ** 2))
    return kernel.astype(np.float32).reshape(1, 1, ksize, ksize)


def create_gaussian_kernel(kernel_size, sigma, channels=3):
    """Return a normalised 2-D Gaussian kernel as a conv weight tensor.

    Args:
        kernel_size: side length of the square kernel.
        sigma: standard deviation of the Gaussian.
        channels: number of identical copies stacked along the first
            dimension (defaults to 3, the historical behaviour).

    Returns:
        Contiguous float tensor of shape
        ``(channels, 1, kernel_size, kernel_size)``; each 2-D slice is the
        outer product of a 1-D Gaussian that sums to one.
    """
    start = (1 - kernel_size) / 2
    end = (1 + kernel_size) / 2
    kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
    kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
    kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
    kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
    return kernel_2d.expand(channels, 1, kernel_size, kernel_size).contiguous()


class wj_bce_Loss(Module):
    """Cross-entropy loss re-weighted by smoothed boundary maps.

    A per-pixel weight map is built from three terms: a boundary map of the
    arg-max prediction, a boundary map of the label, and a constant map.
    The per-pixel cross-entropy is multiplied by that map and averaged.

    Args:
        kernel_size: side length of the Gaussian used to widen boundaries;
            must be a positive odd integer so the weight map keeps the
            spatial size of the input.
        sigma: standard deviation of that Gaussian.
        weights: three coefficients ``(w_pred, w_label, w_const)``.
        weight_fuse: ``'multify'`` (or ``'multiply'``) combines the maps as
            ``w_pred * pred * label + w_const``; any other value uses
            ``w_pred * pred + w_label * label + w_const``.
        classweight: optional per-class weights forwarded to
            ``nn.CrossEntropyLoss``.
        as_loss: stored on the instance; not used by the code in this file.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=11, sigma=2, weights=(0.2, 1.0, 1.0),
                 weight_fuse='add', classweight=None, as_loss=True):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                'kernel_size must be a positive odd integer, got %r'
                % (kernel_size,))
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.as_loss = as_loss
        self.pad = int((self.kernel_size - 1) / 2)
        self.weight_fuse = weight_fuse

        # Non-persistent buffers follow ``.to()`` / ``.cuda()`` and
        # DataParallel replication without adding keys to ``state_dict``.
        self.register_buffer(
            'weights', torch.as_tensor(weights, dtype=torch.float),
            persistent=False)
        self.register_buffer(
            'gaussian_kernel',
            self._create_gaussian_kernel(self.kernel_size, self.sigma),
            persistent=False)
        self.register_buffer(
            'sobel_kernel', torch.tensor(_EDGE_KERNEL, dtype=torch.float),
            persistent=False)
        self.register_buffer('tmax', torch.tensor(1.0), persistent=False)

        # ``is not None`` so NumPy arrays / tensors are accepted as well.
        if classweight is not None:
            self.criterion = nn.CrossEntropyLoss(
                weight=torch.as_tensor(classweight, dtype=torch.float),
                reduction='none')
        else:
            self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, y):
        """Return the mean boundary-weighted cross-entropy.

        Args:
            x: class scores; the class axis is dimension 1.
            y: labels with the shape of ``x`` minus the class axis; cast to
                ``long`` before the cross-entropy is computed.

        Returns:
            Scalar tensor.
        """
        # Keep the historical convenience of following the input's device,
        # but compare devices so a cuda:0 / cuda:1 mismatch is fixed too and
        # the criterion's class weights are moved along with the kernels.
        if self.weights.device != x.device:
            self.to(x.device)
        mix_weights = self._get_mix_weight(x, y)
        loss = self.criterion(x, y.long())
        return torch.mean(loss * mix_weights)

    def _get_weight(self, mask):
        """Return a smoothed boundary map for a label-like map.

        Args:
            mask: tensor of shape ``(bs, h, w)``.

        Returns:
            Float tensor of shape ``(bs, 1, h, w)``.
        """
        device = mask.device
        # Cap values at ``tmax`` (1.0) so every non-zero class counts the
        # same: the map becomes a 0/1 map for non-negative integer labels.
        mask_map = torch.minimum(mask.float(), self.tmax.to(device))
        mask_map = mask_map.unsqueeze(1)
        # Reflection padding keeps image borders from showing up as edges.
        mask_pad = nn.functional.pad(mask_map, (1, 1, 1, 1), mode='reflect')
        edge_kernel = self.sobel_kernel.to(device=device, dtype=torch.float)
        mask_edge = torch.absolute(
            torch.conv2d(mask_pad, edge_kernel, padding=0))
        # Low-pass filter to dilate the boundary.
        smooth_kernel = self.gaussian_kernel.to(
            device=device, dtype=torch.float)
        return torch.conv2d(mask_edge, smooth_kernel, padding=self.pad)

    def _get_mix_weight(self, x, y):
        """Combine prediction, label and constant weights into one map."""
        # Boundary map of the arg-max prediction.
        preds_weights = self._get_weight(torch.argmax(x, dim=1)).squeeze(1)
        # Boundary map of the label.
        labels_weights = self._get_weight(y).squeeze(1)
        # The constant term broadcasts, so no explicit map of ones is needed.
        w_pred, w_label, w_const = (
            self.weights[0], self.weights[1], self.weights[2])
        if self.weight_fuse in ('multify', 'multiply'):
            return w_pred * preds_weights * labels_weights + w_const
        return w_pred * preds_weights + w_label * labels_weights + w_const

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Return a single-channel ``(1, 1, k, k)`` Gaussian conv kernel."""
        return create_gaussian_kernel(kernel_size, sigma, channels=1)


class thFloater(Module):
    """Weighted sum of the cross-entropy losses of two prediction heads.

    Args:
        weights: two coefficients, one per head.
    """

    def __init__(self, weights=(0.5, 0.5)):
        super().__init__()
        self.register_buffer(
            'weights', torch.as_tensor(weights, dtype=torch.float),
            persistent=False)
        self.baseCriterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, y):
        """Return the mean of the weighted per-pixel losses.

        Args:
            x: sequence of exactly two score tensors.
            y: sequence whose first two items are the matching labels.

        Returns:
            Scalar tensor.

        Raises:
            ValueError: if ``x`` does not hold exactly two tensors.
        """
        # Validate before indexing; a real exception survives ``python -O``.
        if len(x) != 2:
            raise ValueError(
                'thFloater expects exactly 2 predictions, got %d' % len(x))
        if self.weights.device != x[0].device:
            self.to(x[0].device)
        loss_river = self.baseCriterion(x[0], y[0].long())
        loss_floater = self.baseCriterion(x[1], y[1].long())
        return torch.mean(
            loss_river * self.weights[0] + loss_floater * self.weights[1])


def main():
    """Run the boundary-weight pipeline step by step on a synthetic map."""
    preds = torch.zeros(8, 100, 100)
    preds[:, :, 50:] = 3.0
    t_max = torch.tensor(1.0)
    # Turn preds into a 0/1 map, (bs, h, w).
    preds_map = torch.minimum(preds, t_max).unsqueeze(1)
    preds_pad = nn.functional.pad(preds_map, (1, 1, 1, 1), mode='reflect')
    # Edge operator parameters.
    edge_kernel = torch.tensor(_EDGE_KERNEL, dtype=torch.float)
    preds_edge = torch.absolute(
        torch.conv2d(preds_pad, edge_kernel, padding=0))
    # Low-pass filter, smooth the boundary.
    ksize, pad, sigma = 11, 5, 2
    kernel = torch.from_numpy(GaussLowPassFiltering(ksize, sigma))
    smooth_edge = torch.conv2d(preds_edge, kernel.float(), padding=pad)
    print('binary map:', tuple(preds_map.shape),
          'smoothed edge map:', tuple(smooth_edge.shape))


def test_loss_moule():
    """Print ``wj_bce_Loss`` for several weight combinations."""
    preds = torch.rand(8, 5, 100, 100)
    targets = torch.zeros(8, 100, 100)
    targets[:, :, 50:] = 3.0
    weight_sets = [
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
    ]
    for weights in weight_sets:
        loss_layer = wj_bce_Loss(
            kernel_size=11, sigma=2, weights=weights, as_loss=True)
        loss = loss_layer(preds, targets)
        print(weights, ' loss: ', loss)


def test_multify_output():
    """Print ``thFloater`` on random two-head predictions."""
    pred1 = torch.rand(8, 2, 100, 100)
    pred2 = torch.rand(8, 5, 100, 100)
    target1 = torch.randint(0, 2, (8, 100, 100))
    target2 = torch.randint(0, 5, (8, 100, 100))
    loss_layer = thFloater(weights=[0.5, 0.5])
    loss = loss_layer([pred1, pred2], [target1, target2])
    print(loss)


if __name__ == '__main__':
    test_multify_output()