128 lines
5.4 KiB
Python
128 lines
5.4 KiB
Python
|
||
import torch
|
||
from torch import nn
|
||
from torch.nn import functional as F
|
||
import cv2
|
||
import numpy as np
|
||
import json
|
||
|
||
def dice_loss_func(input, target, smooth=1.):
    """Return the batch-averaged soft Dice loss between ``input`` and ``target``.

    Every sample (index along dim 0) is flattened and scored as
    ``1 - (2 * sum(x * t) + smooth) / (sum(x) + sum(t) + smooth)``;
    the mean of the per-sample losses is returned.

    Args:
        input: Predictions with the batch on dim 0 (e.g. sigmoid
            probabilities). Non-contiguous tensors are accepted.
        target: Tensor with the same number of elements per sample as
            ``input``.
        smooth: Additive smoothing term in numerator and denominator; it keeps
            the ratio defined when both prediction and target sum to zero.

    Returns:
        0-dim tensor holding the mean Dice loss.
    """
    n = input.size(0)
    # reshape (not view) so permuted / sliced, non-contiguous inputs also work.
    pred = input.reshape(n, -1)
    truth = target.reshape(n, -1)
    intersection = (pred * truth).sum(1)
    dice = (2. * intersection + smooth) / (pred.sum(1) + truth.sum(1) + smooth)
    return (1 - dice).mean()
|
||
|
||
def get_one_hot(label, N):
    """Convert a tensor of class indices into a one-hot tensor.

    Args:
        label: Tensor of class indices of arbitrary shape ``S``; values must
            lie in ``[0, N)`` for ``index_select`` to succeed.
        N: Number of classes.

    Returns:
        Tensor of shape ``S + (N,)`` with the default floating dtype, on the
        same device as ``label``; the last axis holds the one-hot encoding.
    """
    size = list(label.size())
    flat = label.reshape(-1)  # flatten into a 1-D vector of indices
    # Rows of the identity matrix are exactly the one-hot vectors. Build it on
    # label's device rather than hard-coding .cuda(), so CPU tensors and
    # tensors on any GPU work.
    eye = torch.eye(N, device=label.device)
    one_hot = eye.index_select(0, flat.long())  # pick one row per label
    size.append(N)  # append the class axis, then restore the original shape
    return one_hot.view(*size)
|
||
|
||
def get_boundary(gtmasks):
    """Extract a binary boundary map from a batch of masks with a Laplacian filter.

    A 3x3 Laplacian kernel (8 at the centre, -1 at the 8 neighbours) is
    convolved with the masks (zero padding of 1). Pixels whose response
    exceeds 0.1 are marked as boundary (1); all others, including negative
    responses, become 0.

    Args:
        gtmasks: Tensor of shape (N, H, W); a single channel axis is inserted
            here. Any numeric or bool dtype is accepted and cast to float32,
            since conv2d requires input and kernel dtypes to match.

    Returns:
        float32 tensor of shape (N, 1, H, W) with values in {0, 1}, on
        ``gtmasks.device``.
    """
    laplacian_kernel = torch.tensor(
        [-1, -1, -1, -1, 8, -1, -1, -1, -1],
        dtype=torch.float32, device=gtmasks.device).reshape(1, 1, 3, 3)
    # Label masks are usually integer-typed; conv2d would reject them.
    masks = gtmasks.unsqueeze(1).to(laplacian_kernel.dtype)
    response = F.conv2d(masks, laplacian_kernel, padding=1)
    # Negative responses are already below the 0.1 threshold, so a single
    # comparison replaces clamp(min=0) followed by two in-place assignments.
    return (response > 0.1).to(response.dtype)
|
||
|
||
|
||
class DetailAggregateLoss(nn.Module):
    """Boundary ("detail") loss against multi-scale Laplacian edge targets.

    Ground-truth masks are converted into edge maps with a 3x3 Laplacian at
    strides 1, 2 and 4. The stride-2/4 maps are upsampled (nearest) to full
    resolution, and every map is binarised at 0.1. The three maps are fused
    by a 1x1 convolution whose weights are the Parameter ``fuse_kernel``
    (initialised to 0.6 / 0.3 / 0.1), and the result is binarised at 0.1
    again. ``forward`` compares boundary logits with this target using
    BCE-with-logits and ``dice_loss_func``.

    Kernels are created on CUDA when it is available and on CPU otherwise.
    ``laplacian_kernel`` is a non-persistent buffer, so ``.to()``, ``.cuda()``
    and ``.cpu()`` move it together with ``fuse_kernel`` without adding a key
    to ``state_dict()``.
    """

    def __init__(self, *args, **kwargs):
        # *args / **kwargs are accepted and ignored.
        super(DetailAggregateLoss, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 3x3 Laplacian: 8 at the centre, -1 at the 8 neighbours.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32, device=device).reshape(1, 1, 3, 3),
            persistent=False)
        # 1x1 conv weights (out=1, in=3) that fuse the stride-1/2/4 edge maps.
        self.fuse_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32, device=device).reshape(1, 3, 1, 1))

    def _edge_response(self, masks, stride):
        """Return the Laplacian response of (N, 1, H, W) ``masks`` at ``stride``, clamped at 0."""
        response = F.conv2d(masks, self.laplacian_kernel, stride=stride, padding=1)
        return response.clamp(min=0)

    def forward(self, boundary_logits, gtmasks):
        """Return ``(bce_loss, dice_loss)`` of ``boundary_logits`` against edge targets from ``gtmasks``.

        Args:
            boundary_logits: Logits of shape (N, 1, h, w). They are bilinearly
                resized (align_corners=True) to the mask resolution (H, W)
                whenever either spatial dimension differs.
            gtmasks: Masks of shape (N, H, W); a channel axis is inserted, and
                the result is moved to the kernels' device and cast to float32.

        Returns:
            Tuple ``(bce_loss, dice_loss)``. The first is the mean
            BCE-with-logits against the fused target; the second is
            ``dice_loss_func`` applied to the sigmoid of the logits.
        """
        kernel = self.laplacian_kernel
        masks = gtmasks.unsqueeze(1).to(device=kernel.device, dtype=kernel.dtype)

        # Full-resolution edges, binarised at 0.1.
        boundary_targets = (self._edge_response(masks, 1) > 0.1).to(masks.dtype)
        full_size = boundary_targets.shape[2:]

        # Coarser edges: upsample (nearest) to full resolution, then binarise.
        # Only strides 1, 2 and 4 are fused: fuse_kernel has 3 input channels.
        scales = [boundary_targets]
        for stride in (2, 4):
            coarse = F.interpolate(self._edge_response(masks, stride),
                                   full_size, mode='nearest')
            scales.append((coarse > 0.1).to(coarse.dtype))

        # (N, 3, H, W) -> (N, 1, H, W) through the 1x1 fusion conv.
        targets_pyramid = F.conv2d(torch.cat(scales, dim=1), self.fuse_kernel)
        # NOTE(review): binarised in place on the conv output, as before. Every
        # element is overwritten with a constant, so the gradient reaching
        # fuse_kernel through this target should be zero. Kept in place to
        # preserve the existing autograd graph.
        targets_pyramid[targets_pyramid > 0.1] = 1
        targets_pyramid[targets_pyramid <= 0.1] = 0

        # Compare both spatial dims; checking only the width missed logits whose
        # height alone differs.
        if boundary_logits.shape[2:] != full_size:
            boundary_logits = F.interpolate(
                boundary_logits, full_size, mode='bilinear', align_corners=True)

        bce_loss = F.binary_cross_entropy_with_logits(boundary_logits, targets_pyramid)
        dice_loss = dice_loss_func(torch.sigmoid(boundary_logits), targets_pyramid)
        return bce_loss, dice_loss

    def get_params(self):
        """Return all parameters of this module as a list (the no-weight-decay group).

        ``self.parameters()`` already recurses into submodules, so each
        parameter appears exactly once. Concatenating ``module.parameters()``
        over ``named_modules()`` would repeat submodule parameters.
        """
        return list(self.parameters())
|
||
|
||
if __name__ == '__main__':
    # Smoke test: convert one Cityscapes validation label image to train ids and
    # evaluate the loss. The label map itself is fed in as (meaningless) logits.
    torch.manual_seed(15)
    with open('../cityscapes_info.json', 'r', encoding='utf-8') as fr:
        labels_info = json.load(fr)
    # Raw Cityscapes label id -> training id.
    lb_map = {el['id']: el['trainId'] for el in labels_info}

    img_path = 'data/gtFine/val/frankfurt/frankfurt_000001_037705_gtFine_labelIds.png'
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising when a file can't be read.
        raise FileNotFoundError('could not read label image: {}'.format(img_path))

    label = np.zeros(img.shape, np.uint8)
    for k, v in lb_map.items():
        label[img == k] = v

    # Use the GPU when present instead of hard-coding CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img_tensor = torch.from_numpy(label).to(device=device, dtype=torch.float32).unsqueeze(0)

    detailAggregateLoss = DetailAggregateLoss()
    for param in detailAggregateLoss.parameters():
        print(param)

    bce_loss, dice_loss = detailAggregateLoss(torch.unsqueeze(img_tensor, 0), img_tensor)
    print(bce_loss, dice_loss)