  1. """Point-wise Spatial Attention Network"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.nn import _ConvBNReLU
  6. from core.models.segbase import SegBaseModel
  7. from core.models.fcn import _FCNHead
  8. __all__ = ['PSANet', 'get_psanet', 'get_psanet_resnet50_voc', 'get_psanet_resnet101_voc',
  9. 'get_psanet_resnet152_voc', 'get_psanet_resnet50_citys', 'get_psanet_resnet101_citys',
  10. 'get_psanet_resnet152_citys']


class PSANet(SegBaseModel):
    r"""PSANet

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'resnet50'; one of
        'resnet50', 'resnet101' or 'resnet152').
    norm_layer : object
        Normalization layer used in the backbone network (default: :class:`nn.BatchNorm2d`;
        a synchronized cross-GPU batch normalization layer can be passed in for multi-GPU training).
    aux : bool
        Whether to add the auxiliary FCN loss head.

    Reference:
        Hengshuang Zhao, et al. "PSANet: Point-wise Spatial Attention Network for Scene Parsing."
        ECCV 2018.
    """
    def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
        super(PSANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
        self.head = _PSAHead(nclass, **kwargs)
        if aux:
            self.auxlayer = _FCNHead(1024, nclass, **kwargs)

        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        size = x.size()[2:]
        _, _, c3, c4 = self.base_forward(x)
        outputs = list()
        x = self.head(c4)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        # Return all outputs so the auxiliary prediction stays available for the aux loss.
        return tuple(outputs)
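
# Data-flow note (a sketch of what forward() above does): base_forward returns the
# stride-8 feature maps of the dilated ResNet backbone; c3 (1024 channels) feeds the
# optional _FCNHead auxiliary classifier and c4 (2048 channels) feeds the PSA head.
# Both predictions are bilinearly upsampled back to the input resolution.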


class _PSAHead(nn.Module):
    def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_PSAHead, self).__init__()
        # psa_out_channels = (crop_size // 8) ** 2, i.e. (480 // 8) ** 2 = 3600
        self.psa = _PointwiseSpatialAttention(2048, 3600, norm_layer)

        self.conv_post = _ConvBNReLU(1024, 2048, 1, norm_layer=norm_layer)
        self.project = nn.Sequential(
            _ConvBNReLU(4096, 512, 3, padding=1, norm_layer=norm_layer),
            nn.Dropout2d(0.1, False),
            nn.Conv2d(512, nclass, 1))

    def forward(self, x):
        global_feature = self.psa(x)          # (n, 1024, h, w): collect + distribute branches
        out = self.conv_post(global_feature)  # (n, 2048, h, w)
        out = torch.cat([x, out], dim=1)      # (n, 4096, h, w): backbone features + PSA features
        out = self.project(out)
        return out


class _PointwiseSpatialAttention(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_PointwiseSpatialAttention, self).__init__()
        reduced_channels = 512
        self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)
        self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)

    def forward(self, x):
        collect_fm = self.collect_attention(x)
        distribute_fm = self.distribute_attention(x)
        psa_fm = torch.cat([collect_fm, distribute_fm], dim=1)
        return psa_fm
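
# Channel note: each branch returns reduced_channels (512) feature channels, so the
# concatenated psa_fm has 1024 channels, which is exactly what _PSAHead.conv_post
# expects as its input.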


class _AttentionGeneration(nn.Module):
    # Produces Z: (n, C2, H, W). Unlike the paper, this is not the over-completed
    # attention formulation: the attention map is a plain (H*W) x (H*W) matrix.
    def __init__(self, in_channels, reduced_channels, out_channels, norm_layer, **kwargs):
        super(_AttentionGeneration, self).__init__()
        self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer)
        self.attention = nn.Sequential(
            _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer),
            nn.Conv2d(reduced_channels, out_channels, 1, bias=False))

        self.reduced_channels = reduced_channels

    def forward(self, x):
        reduce_x = self.conv_reduce(x)
        attention = self.attention(reduce_x)
        n, c, h, w = attention.size()                           # c = out_channels = 3600
        attention = attention.view(n, c, -1)                    # (n, 3600, h*w)
        reduce_x = reduce_x.view(n, self.reduced_channels, -1)  # (n, 512, h*w)
        fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1))
        fm = fm.view(n, self.reduced_channels, h, w)            # (n, 512, h, w)
        return fm
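
# Shape walk-through (a sketch assuming the 480x480 crop from __main__, so the
# stride-8 feature map is 60x60 and h*w = 3600):
#   reduce_x:  (n, 512, 60, 60)  -> view -> (n, 512, 3600)
#   attention: (n, 3600, 60, 60) -> view -> (n, 3600, 3600)
#   bmm:       (n, 512, 3600) @ (n, 3600, 3600) -> (n, 512, 3600) -> (n, 512, 60, 60)
# The bmm only works when h*w equals the hard-coded out_channels (3600), i.e. the
# attention map size is tied to the training crop size.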


def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
               pretrained_base=False, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from core.data.dataloader import datasets
    model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs.get('local_rank', 'cpu'))  # default to CPU if no device rank is passed
        model.load_state_dict(torch.load(get_model_file('psanet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
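
# Example usage (a sketch; dataset keys follow the acronyms table above, so e.g.
# the ADE20K variant would be built with):
#   model = get_psanet('ade20k', backbone='resnet101', pretrained_base=True)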


def get_psanet_resnet50_voc(**kwargs):
    return get_psanet('pascal_voc', 'resnet50', **kwargs)


def get_psanet_resnet101_voc(**kwargs):
    return get_psanet('pascal_voc', 'resnet101', **kwargs)


def get_psanet_resnet152_voc(**kwargs):
    return get_psanet('pascal_voc', 'resnet152', **kwargs)


def get_psanet_resnet50_citys(**kwargs):
    return get_psanet('citys', 'resnet50', **kwargs)


def get_psanet_resnet101_citys(**kwargs):
    return get_psanet('citys', 'resnet101', **kwargs)


def get_psanet_resnet152_citys(**kwargs):
    return get_psanet('citys', 'resnet152', **kwargs)


if __name__ == '__main__':
    model = get_psanet_resnet50_voc()
    img = torch.randn(1, 3, 480, 480)
    outputs = model(img)
    print(outputs[0].shape)  # expected: torch.Size([1, 21, 480, 480]) for the 21 VOC classes
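
# Training-time sketch (an assumption, not part of this file): with aux=True the model
# returns (main_out, aux_out), both upsampled to the input size, so a typical recipe
# combines two cross-entropy losses, e.g. loss = ce(main_out, y) + 0.4 * ce(aux_out, y):
#   model = get_psanet_resnet50_voc(aux=True)
#   main_out, aux_out = model(img)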