|
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- import cv2
- import numpy as np
- import json
-
def dice_loss_func(input, target, smooth=1.):
    """Return the soft Dice loss averaged over the batch.

    Each sample (first dimension) is flattened independently. Its soft Dice
    coefficient ``(2 * sum(x * y) + smooth) / (sum(x) + sum(y) + smooth)`` is
    computed, and the loss is ``1 - dice`` averaged over the batch.

    Args:
        input: prediction tensor whose first dimension is the batch. In this
            file it is called with sigmoid outputs; the value range is not
            enforced here.
        target: tensor with the same batch size and the same number of
            elements per sample as ``input``.
        smooth: additive smoothing term. It keeps the ratio defined (and equal
            to 1) when both ``input`` and ``target`` are all zeros.

    Returns:
        Scalar tensor: mean of ``1 - dice`` over the batch.
    """
    n = input.size(0)
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors,
    # are accepted instead of raising a RuntimeError.
    iflat = input.reshape(n, -1)
    tflat = target.reshape(n, -1)
    intersection = (iflat * tflat).sum(1)
    dice = (2. * intersection + smooth) / (iflat.sum(1) + tflat.sum(1) + smooth)
    return (1 - dice).mean()
-
def get_one_hot(label, N):
    """Convert a tensor of class indices into a one-hot float tensor.

    Args:
        label: tensor of class indices in ``[0, N)``, of any shape. It is cast
            to ``long`` before indexing.
        N: number of classes.

    Returns:
        float32 tensor of shape ``label.shape + (N,)`` on ``label``'s device.
    """
    size = list(label.size())
    label = label.reshape(-1)  # flatten to a vector of indices
    # Rows of the identity matrix are the one-hot codes. Build it on the
    # label's device instead of hard-coding CUDA, so CPU labels work too.
    eye = torch.eye(N, device=label.device)
    one_hot = eye.index_select(0, label.long())
    size.append(N)  # append the class dimension, then restore the original shape
    return one_hot.view(*size)
-
def get_boundary(gtmasks):
    """Return a binary boundary map of ``gtmasks`` using a 3x3 Laplacian filter.

    Args:
        gtmasks: mask tensor, expected to be ``(B, H, W)``. A channel axis is
            inserted at dim 1 before the convolution. Any numeric/bool dtype
            is accepted; it is cast to float32 because ``conv2d`` requires the
            input and kernel dtypes to match.

    Returns:
        float32 tensor of shape ``(B, 1, H, W)`` on ``gtmasks``' device. It is
        1 where the (negatives-clamped) Laplacian response exceeds 0.1 and 0
        elsewhere.
    """
    laplacian_kernel = torch.tensor(
        [-1, -1, -1, -1, 8, -1, -1, -1, -1],
        dtype=torch.float32, device=gtmasks.device).reshape(1, 1, 3, 3)
    masks = gtmasks.unsqueeze(1).to(torch.float32)
    # padding=1 keeps the spatial size. Negative responses (the outer side of
    # an edge) are clamped away before thresholding.
    boundary_targets = F.conv2d(masks, laplacian_kernel, padding=1).clamp(min=0)
    boundary_targets[boundary_targets > 0.1] = 1
    boundary_targets[boundary_targets <= 0.1] = 0
    return boundary_targets
-
-
class DetailAggregateLoss(nn.Module):
    """Boundary (detail) loss against a multi-scale Laplacian target.

    The target is built from the ground-truth mask in four steps:
      1. Apply a 3x3 Laplacian at strides 1, 2 and 4 (negatives clamped).
      2. Upsample the strided maps back to full resolution and binarise at 0.1.
      3. Fuse the three maps with ``fuse_kernel``, a 1x1 convolution
         initialised to weights 0.6 / 0.3 / 0.1.
      4. Binarise the fused result at 0.1.

    ``forward`` returns the BCE-with-logits loss and the Dice loss between
    ``boundary_logits`` and that target.
    """

    def __init__(self, *args, **kwargs):
        # *args / **kwargs are accepted for call-site compatibility and ignored.
        super().__init__()

        # Use CUDA when it is available (the previously hard-coded placement);
        # fall back to CPU instead of crashing on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Non-persistent buffer: module.to(...) moves it, while state_dict()
        # still contains only ``fuse_kernel``, so existing checkpoints stay
        # loadable.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32, device=device).reshape(1, 1, 3, 3),
            persistent=False)

        self.fuse_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32, device=device).reshape(1, 3, 1, 1))

    @staticmethod
    def _binarize_(t, threshold=0.1):
        """Set entries of ``t`` above ``threshold`` to 1 and the rest to 0, in place."""
        t[t > threshold] = 1
        t[t <= threshold] = 0
        return t

    def _laplacian(self, masks, stride):
        """Return the negatives-clamped Laplacian response of ``masks`` at ``stride``."""
        kernel = self.laplacian_kernel.to(masks.device)
        return F.conv2d(masks, kernel, stride=stride, padding=1).clamp(min=0)

    def forward(self, boundary_logits, gtmasks):
        """Compute the boundary losses.

        Args:
            boundary_logits: raw (pre-sigmoid) boundary predictions, 4-D. After
                an optional bilinear resize to the mask resolution, they must
                match the target shape ``(B, 1, H, W)``.
            gtmasks: ground-truth masks of shape ``(B, H, W)``. They are cast to
                float32 and moved to ``boundary_logits``' device.

        Returns:
            ``(bce_loss, dice_loss)``: the BCE-with-logits loss, and the Dice
            loss computed on ``sigmoid(boundary_logits)``, both against the
            fused binary target.
        """
        device = boundary_logits.device
        # Cast and move the masks once; previously they were forced to CUDA
        # four times.
        masks = gtmasks.unsqueeze(1).to(device=device, dtype=torch.float32)

        boundary_targets = self._binarize_(self._laplacian(masks, 1))
        full_size = boundary_targets.shape[2:]

        # Coarser detail maps, upsampled back to full resolution and binarised.
        # Only strides 1, 2 and 4 are fused (fuse_kernel has 3 input channels).
        pyramid_levels = [boundary_targets]
        for stride in (2, 4):
            coarse = self._laplacian(masks, stride)
            upsampled = F.interpolate(coarse, full_size, mode='nearest')
            pyramid_levels.append(self._binarize_(upsampled))

        # (B, 3, H, W) -> (B, 1, H, W) through the learnable 1x1 fuse conv.
        pyramid = torch.cat(pyramid_levels, dim=1)
        fuse_kernel = self.fuse_kernel.to(device)
        boundary_target = self._binarize_(F.conv2d(pyramid, fuse_kernel))

        # Compare both spatial dims; checking only the width missed
        # height-only mismatches.
        if boundary_logits.shape[-2:] != boundary_target.shape[-2:]:
            boundary_logits = F.interpolate(
                boundary_logits, full_size, mode='bilinear', align_corners=True)

        bce_loss = F.binary_cross_entropy_with_logits(boundary_logits, boundary_target)
        dice_loss = dice_loss_func(torch.sigmoid(boundary_logits), boundary_target)
        return bce_loss, dice_loss

    def get_params(self):
        """Return every parameter of this loss as one flat list.

        All of them are placed in the group the original code named
        ``nowd_params`` (presumably "no weight decay"; confirm with the
        optimizer setup).
        """
        # self.parameters() already recurses into submodules. Concatenating
        # module.parameters() over named_modules() would duplicate parameters.
        return list(self.parameters())
-
if __name__ == '__main__':
    # Smoke test: run the detail loss on one Cityscapes label map and print
    # the fuse parameters and both loss values. Paths are relative to the
    # current working directory.
    torch.manual_seed(15)
    with open('../cityscapes_info.json', 'r', encoding='utf-8') as fr:
        labels_info = json.load(fr)
    lb_map = {el['id']: el['trainId'] for el in labels_info}

    img_path = 'data/gtFine/val/frankfurt/frankfurt_000001_037705_gtFine_labelIds.png'
    img = cv2.imread(img_path, 0)
    if img is None:
        # cv2.imread returns None (it does not raise) for missing or unreadable files.
        raise FileNotFoundError('could not read label image: {}'.format(img_path))

    # Map raw label ids to train ids; ids absent from lb_map stay 0.
    label = np.zeros(img.shape, np.uint8)
    for k, v in lb_map.items():
        label[img == k] = v

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # (H, W) uint8 -> (1, H, W) float32 on the chosen device.
    img_tensor = torch.from_numpy(label).to(device=device, dtype=torch.float32).unsqueeze(0)

    detailAggregateLoss = DetailAggregateLoss().to(device)
    for param in detailAggregateLoss.parameters():
        print(param)

    # The label map itself serves as dummy (1, 1, H, W) logits.
    bce_loss, dice_loss = detailAggregateLoss(img_tensor.unsqueeze(0), img_tensor)
    print(bce_loss, dice_loss)
|