STDC-th/loss/detail_loss.py

128 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
from torch.nn import functional as F
import cv2
import numpy as np
import json
def dice_loss_func(input, target, smooth=1.):
    """Return the batch-mean soft Dice loss between ``input`` and ``target``.

    Every sample is flattened to a vector and scored as
    ``1 - (2 * sum(input * target) + smooth) / (sum(input) + sum(target) + smooth)``;
    the per-sample losses are then averaged over the batch.

    Args:
        input: Tensor whose first dimension is the batch dimension.
        target: Tensor with the same batch size and the same number of
            elements per sample as ``input``.
        smooth: Constant added to numerator and denominator. It keeps the
            ratio defined when both tensors sum to zero. Defaults to ``1.``.

    Returns:
        Zero-dimensional tensor holding the mean loss over the batch.
    """
    n = input.size(0)
    # reshape (rather than view) also accepts non-contiguous tensors such as
    # the result of permute/transpose, copying only when it has to.
    iflat = input.reshape(n, -1)
    tflat = target.reshape(n, -1)
    intersection = (iflat * tflat).sum(1)
    denominator = iflat.sum(1) + tflat.sum(1) + smooth
    loss = 1 - (2. * intersection + smooth) / denominator
    return loss.mean()
def get_one_hot(label, N):
    """Convert a tensor of class indices into a one-hot encoded tensor.

    Args:
        label: Tensor of class indices; each value must be a valid row index
            into an ``N x N`` identity matrix. Any shape is accepted.
        N: Number of classes (length of each one-hot vector).

    Returns:
        Tensor of shape ``label.shape + (N,)`` in the default float dtype,
        located on the same device as ``label``.
    """
    size = list(label.size())
    flat = label.reshape(-1)  # flatten to a vector of class indices
    # Row i of the identity matrix is the one-hot code of class i. Build it on
    # the label's own device instead of hard-coding CUDA, so the function also
    # works on CPU-only machines.
    identity = torch.eye(N, device=label.device)
    one_hot = identity.index_select(0, flat.long())
    # Append the class count so the result can be reshaped back to the
    # original label shape plus a trailing class axis.
    size.append(N)
    return one_hot.view(*size)
def get_boundary(gtmasks):
    """Extract a binary boundary map from label masks with a Laplacian filter.

    Args:
        gtmasks: Label masks, expected as ``(N, H, W)``: a channel axis is
            inserted at dim 1 before the single-channel 3x3 convolution.
            Integer and floating-point tensors are both accepted.

    Returns:
        float32 tensor of shape ``(N, 1, H, W)`` on the device of ``gtmasks``
        whose entries are 1 where the clamped Laplacian response exceeds 0.1
        and 0 elsewhere.
    """
    laplacian_kernel = torch.tensor(
        [-1, -1, -1, -1, 8, -1, -1, -1, -1],
        dtype=torch.float32, device=gtmasks.device).reshape(1, 1, 3, 3)
    # conv2d needs input and kernel to share a dtype; label maps are commonly
    # stored as integers, so cast explicitly to the kernel's float32.
    masks = gtmasks.unsqueeze(1).to(torch.float32)
    boundary_targets = F.conv2d(masks, laplacian_kernel, padding=1)
    # Negative responses are discarded; only positive ones can mark an edge.
    boundary_targets = boundary_targets.clamp(min=0)
    boundary_targets[boundary_targets > 0.1] = 1
    boundary_targets[boundary_targets <= 0.1] = 0
    return boundary_targets
class DetailAggregateLoss(nn.Module):
    """Boundary loss: BCE and Dice against Laplacian-derived edge targets.

    The ground-truth masks are filtered with a 3x3 Laplacian kernel at
    strides 1, 2 and 4. Each response is clamped to be non-negative, brought
    back to full resolution with nearest-neighbour interpolation and
    binarised at 0.1. The three maps are then combined by a 1x1 convolution
    with ``fuse_kernel`` (weights 0.6 / 0.3 / 0.1) and binarised again to
    give the final boundary target.

    Attributes set in ``__init__``:
        laplacian_kernel: ``(1, 1, 3, 3)`` float32 tensor (plain attribute,
            not a parameter or buffer).
        fuse_kernel: ``(1, 3, 1, 1)`` float32 ``nn.Parameter``.
    """

    def __init__(self, *args, **kwargs):
        """Create the fixed Laplacian kernel and the fusion weights.

        Positional and keyword arguments are accepted and ignored.
        """
        super(DetailAggregateLoss, self).__init__()
        # Keep the historical placement (CUDA) when a GPU exists, but fall
        # back to CPU so the module can be constructed anywhere.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.laplacian_kernel = torch.tensor(
            [-1, -1, -1, -1, 8, -1, -1, -1, -1],
            dtype=torch.float32, device=device).reshape(1, 1, 3, 3)
        self.fuse_kernel = torch.nn.Parameter(torch.tensor(
            [[6. / 10], [3. / 10], [1. / 10]],
            dtype=torch.float32, device=device).reshape(1, 3, 1, 1))

    @staticmethod
    def _binarize_(tensor, threshold=0.1):
        """In place: set entries above ``threshold`` to 1, those at or below it to 0."""
        tensor[tensor > threshold] = 1
        tensor[tensor <= threshold] = 0
        return tensor

    def forward(self, boundary_logits, gtmasks):
        """Compute the BCE and Dice losses of the predicted boundary logits.

        Args:
            boundary_logits: Raw (pre-sigmoid) boundary predictions,
                ``(N, 1, h, w)``. They are bilinearly resized to the mask
                resolution when their spatial size differs from it.
            gtmasks: Ground-truth label masks, expected as ``(N, H, W)``; a
                channel axis is inserted at dim 1. They are cast to float32
                and moved to the device of ``boundary_logits``.

        Returns:
            Tuple ``(bce_loss, dice_loss)`` of scalar tensors.
        """
        # Everything follows the logits' device so the module works on CPU,
        # on any GPU index, and without an explicit ``.to()`` on the module.
        device = boundary_logits.device
        laplacian = self.laplacian_kernel.to(device)
        fuse = self.fuse_kernel.to(device)
        masks = gtmasks.unsqueeze(1).to(device=device, dtype=torch.float32)

        boundary_targets = self._binarize_(
            F.conv2d(masks, laplacian, padding=1).clamp(min=0))
        full_size = boundary_targets.shape[2:]

        # Coarser edge maps: strided Laplacian, upsampled back to full size.
        # Only the scales that take part in the fusion are computed.
        scales = [boundary_targets]
        for stride in (2, 4):
            coarse = F.conv2d(
                masks, laplacian, stride=stride, padding=1).clamp(min=0)
            upsampled = F.interpolate(coarse, full_size, mode='nearest')
            scales.append(self._binarize_(upsampled))

        # (N, 3, H, W): one channel per scale, fused by the 1x1 convolution.
        boundary_targets_pyramids = torch.cat(scales, dim=1)
        boundary_targets_pyramid = self._binarize_(
            F.conv2d(boundary_targets_pyramids, fuse))

        # Compare height and width, not just the last dimension.
        if boundary_logits.shape[2:] != full_size:
            boundary_logits = F.interpolate(
                boundary_logits, full_size, mode='bilinear', align_corners=True)

        bce_loss = F.binary_cross_entropy_with_logits(
            boundary_logits, boundary_targets_pyramid)
        dice_loss = dice_loss_func(
            torch.sigmoid(boundary_logits), boundary_targets_pyramid)
        return bce_loss, dice_loss

    def get_params(self):
        """Return this module's parameters as a list.

        NOTE(review): the previous implementation collected them under the
        name ``nowd_params``, so callers presumably place them in a
        no-weight-decay optimizer group - confirm against the training script.
        """
        return list(self.parameters())
if __name__ == '__main__':
    # Smoke test: load one Cityscapes label image, remap its label ids to
    # train ids, and run DetailAggregateLoss with the label map used both as
    # "logits" and as ground truth.
    torch.manual_seed(15)
    # Use the GPU when one is present; otherwise stay on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with open('../cityscapes_info.json', 'r') as fr:
        labels_info = json.load(fr)
    lb_map = {el['id']: el['trainId'] for el in labels_info}

    img_path = 'data/gtFine/val/frankfurt/frankfurt_000001_037705_gtFine_labelIds.png'
    img = cv2.imread(img_path, 0)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising, so fail here with a clear message.
    if img is None:
        raise FileNotFoundError('cannot read label image: {}'.format(img_path))

    label = np.zeros(img.shape, np.uint8)
    for k, v in lb_map.items():
        label[img == k] = v

    img_tensor = torch.from_numpy(label).to(device)
    img_tensor = torch.unsqueeze(img_tensor, 0).float()  # (1, H, W)

    detailAggregateLoss = DetailAggregateLoss()
    for param in detailAggregateLoss.parameters():
        print(param)

    bce_loss, dice_loss = detailAggregateLoss(torch.unsqueeze(img_tensor, 0), img_tensor)
    print(bce_loss, dice_loss)