高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

detail_loss.py 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4. import cv2
  5. import numpy as np
  6. import json
  7. def dice_loss_func(input, target):
  8. smooth = 1.
  9. n = input.size(0)
  10. iflat = input.view(n, -1)
  11. tflat = target.view(n, -1)
  12. intersection = (iflat * tflat).sum(1)
  13. loss = 1 - ((2. * intersection + smooth) /
  14. (iflat.sum(1) + tflat.sum(1) + smooth))
  15. return loss.mean()
  16. def get_one_hot(label, N):
  17. size = list(label.size())
  18. label = label.view(-1) # reshape 为向量
  19. ones = torch.sparse.torch.eye(N).cuda()
  20. ones = ones.index_select(0, label.long()) # 用上面的办法转为换one hot
  21. size.append(N) # 把类别输目添到size的尾后,准备reshape回原来的尺寸
  22. return ones.view(*size)
  23. def get_boundary(gtmasks):
  24. laplacian_kernel = torch.tensor(
  25. [-1, -1, -1, -1, 8, -1, -1, -1, -1],
  26. dtype=torch.float32, device=gtmasks.device).reshape(1, 1, 3, 3).requires_grad_(False)
  27. # boundary_logits = boundary_logits.unsqueeze(1)
  28. boundary_targets = F.conv2d(gtmasks.unsqueeze(1), laplacian_kernel, padding=1)
  29. boundary_targets = boundary_targets.clamp(min=0)
  30. boundary_targets[boundary_targets > 0.1] = 1
  31. boundary_targets[boundary_targets <= 0.1] = 0
  32. return boundary_targets
  33. class DetailAggregateLoss(nn.Module):
  34. def __init__(self, *args, **kwargs):
  35. super(DetailAggregateLoss, self).__init__()
  36. self.laplacian_kernel = torch.tensor(
  37. [-1, -1, -1, -1, 8, -1, -1, -1, -1],
  38. dtype=torch.float32).reshape(1, 1, 3, 3).requires_grad_(False).type(torch.cuda.FloatTensor)
  39. self.fuse_kernel = torch.nn.Parameter(torch.tensor([[6./10], [3./10], [1./10]],
  40. dtype=torch.float32).reshape(1, 3, 1, 1).type(torch.cuda.FloatTensor))
  41. def forward(self, boundary_logits, gtmasks):
  42. # boundary_logits = boundary_logits.unsqueeze(1)
  43. boundary_targets = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, padding=1)
  44. boundary_targets = boundary_targets.clamp(min=0)
  45. boundary_targets[boundary_targets > 0.1] = 1
  46. boundary_targets[boundary_targets <= 0.1] = 0
  47. boundary_targets_x2 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=2, padding=1)
  48. boundary_targets_x2 = boundary_targets_x2.clamp(min=0)
  49. boundary_targets_x4 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=4, padding=1)
  50. boundary_targets_x4 = boundary_targets_x4.clamp(min=0)
  51. boundary_targets_x8 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=8, padding=1)
  52. boundary_targets_x8 = boundary_targets_x8.clamp(min=0)
  53. boundary_targets_x8_up = F.interpolate(boundary_targets_x8, boundary_targets.shape[2:], mode='nearest')
  54. boundary_targets_x4_up = F.interpolate(boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
  55. boundary_targets_x2_up = F.interpolate(boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')
  56. boundary_targets_x2_up[boundary_targets_x2_up > 0.1] = 1
  57. boundary_targets_x2_up[boundary_targets_x2_up <= 0.1] = 0
  58. boundary_targets_x4_up[boundary_targets_x4_up > 0.1] = 1
  59. boundary_targets_x4_up[boundary_targets_x4_up <= 0.1] = 0
  60. boundary_targets_x8_up[boundary_targets_x8_up > 0.1] = 1
  61. boundary_targets_x8_up[boundary_targets_x8_up <= 0.1] = 0
  62. boudary_targets_pyramids = torch.stack((boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), dim=1)
  63. boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2)
  64. boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids, self.fuse_kernel)
  65. boudary_targets_pyramid[boudary_targets_pyramid > 0.1] = 1
  66. boudary_targets_pyramid[boudary_targets_pyramid <= 0.1] = 0
  67. if boundary_logits.shape[-1] != boundary_targets.shape[-1]:
  68. boundary_logits = F.interpolate(
  69. boundary_logits, boundary_targets.shape[2:], mode='bilinear', align_corners=True)
  70. bce_loss = F.binary_cross_entropy_with_logits(boundary_logits, boudary_targets_pyramid)
  71. dice_loss = dice_loss_func(torch.sigmoid(boundary_logits), boudary_targets_pyramid)
  72. return bce_loss, dice_loss
  73. def get_params(self):
  74. wd_params, nowd_params = [], []
  75. for name, module in self.named_modules():
  76. nowd_params += list(module.parameters())
  77. return nowd_params
  78. if __name__ == '__main__':
  79. torch.manual_seed(15)
  80. with open('../cityscapes_info.json', 'r') as fr:
  81. labels_info = json.load(fr)
  82. lb_map = {el['id']: el['trainId'] for el in labels_info}
  83. img_path = 'data/gtFine/val/frankfurt/frankfurt_000001_037705_gtFine_labelIds.png'
  84. img = cv2.imread(img_path, 0)
  85. label = np.zeros(img.shape, np.uint8)
  86. for k, v in lb_map.items():
  87. label[img == k] = v
  88. img_tensor = torch.from_numpy(label).cuda()
  89. img_tensor = torch.unsqueeze(img_tensor, 0).type(torch.cuda.FloatTensor)
  90. detailAggregateLoss = DetailAggregateLoss()
  91. for param in detailAggregateLoss.parameters():
  92. print(param)
  93. bce_loss, dice_loss = detailAggregateLoss(torch.unsqueeze(img_tensor, 0), img_tensor)
  94. print(bce_loss, dice_loss)