  1. """Dual Attention Network"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.segbase import SegBaseModel
  6. __all__ = ['DANet', 'get_danet', 'get_danet_resnet50_citys',
  7. 'get_danet_resnet101_citys', 'get_danet_resnet152_citys']


class DANet(SegBaseModel):
    r"""Dual Attention Network

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'resnet50'; one of 'resnet50',
        'resnet101' or 'resnet152').
    norm_layer : object
        Normalization layer used in the backbone network (default: :class:`nn.BatchNorm2d`;
        use a synchronized variant for cross-GPU batch normalization).
    aux : bool
        Auxiliary loss.

    Reference:
        Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang, and Hanqing Lu.
        "Dual Attention Network for Scene Segmentation." *CVPR*, 2019
    """

    def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs):
        super(DANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
        self.head = _DAHead(2048, nclass, aux, **kwargs)
        self.__setattr__('exclusive', ['head'])

    def forward(self, x):
        size = x.size()[2:]
        _, _, c3, c4 = self.base_forward(x)
        outputs = []
        x = self.head(c4)
        x0 = F.interpolate(x[0], size, mode='bilinear', align_corners=True)
        outputs.append(x0)
        if self.aux:
            x1 = F.interpolate(x[1], size, mode='bilinear', align_corners=True)
            x2 = F.interpolate(x[2], size, mode='bilinear', align_corners=True)
            outputs.append(x1)
            outputs.append(x2)
        # return outputs  # uncomment to also return the two auxiliary predictions
        return outputs[0]
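

# A minimal sketch (hypothetical helper, not part of the original file): the
# network maps an (N, 3, H, W) batch to (N, nclass, H, W) logits, mirroring
# the check in the __main__ block at the bottom of this file.
def _check_danet_shapes():
    model = DANet(4, pretrained_base=False)
    logits = model(torch.randn(2, 3, 512, 512))
    assert logits.shape == (2, 4, 512, 512)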


class _PositionAttentionModule(nn.Module):
    """Position attention module"""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        # feat_b: (N, HW, C//8), feat_c: (N, C//8, HW) -> attention_s: (N, HW, HW)
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
        out = self.alpha * feat_e + x  # learnable residual scale; alpha starts at zero
        return out
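

# A minimal sketch (hypothetical helper, not part of the original file):
# position attention preserves the input shape, so it can be dropped into
# any (N, C, H, W) feature map.
def _check_pam_shape():
    pam = _PositionAttentionModule(64)
    assert pam(torch.randn(2, 64, 32, 32)).shape == (2, 64, 32, 32)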


class _ChannelAttentionModule(nn.Module):
    """Channel attention module"""

    def __init__(self, **kwargs):
        super(_ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_a = x.view(batch_size, -1, height * width)
        feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
        # channel affinity matrix: (N, C, C)
        attention = torch.bmm(feat_a, feat_a_transpose)
        # subtract each row's max before the softmax; since softmax is
        # shift-invariant, this equals softmax(-attention) computed stably
        attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
        attention = self.softmax(attention_new)
        feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
        out = self.beta * feat_e + x  # learnable residual scale; beta starts at zero
        return out
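

# A minimal sketch (hypothetical helper, not part of the original file): the
# module takes no channel argument because the affinity matrix is (C x C)
# for whatever C the input has.
def _check_cam_shape():
    cam = _ChannelAttentionModule()
    assert cam(torch.randn(2, 64, 32, 32)).shape == (2, 64, 32, 32)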


class _DAHead(nn.Module):
    def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DAHead, self).__init__()
        self.aux = aux
        inter_channels = in_channels // 4
        self.conv_p1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.pam = _PositionAttentionModule(inter_channels, **kwargs)
        self.cam = _ChannelAttentionModule(**kwargs)
        self.conv_p2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, nclass, 1)
        )
        if aux:
            self.conv_p3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )
            self.conv_c3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )

    def forward(self, x):
        feat_p = self.conv_p1(x)
        feat_p = self.pam(feat_p)
        feat_p = self.conv_p2(feat_p)
        feat_c = self.conv_c1(x)
        feat_c = self.cam(feat_c)
        feat_c = self.conv_c2(feat_c)
        feat_fusion = feat_p + feat_c
        outputs = []
        fusion_out = self.out(feat_fusion)
        outputs.append(fusion_out)
        if self.aux:
            p_out = self.conv_p3(feat_p)
            c_out = self.conv_c3(feat_c)
            outputs.append(p_out)
            outputs.append(c_out)
        return tuple(outputs)
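

# A minimal sketch (hypothetical helper, not part of the original file): with
# aux=True the head returns (fusion, position, channel) logits; with aux=False
# only the fused logits. The 3x3 pad-1 convs keep the spatial size unchanged.
def _check_dahead_outputs():
    head = _DAHead(2048, nclass=19, aux=True)
    outs = head(torch.randn(1, 2048, 16, 16))
    assert len(outs) == 3 and outs[0].shape == (1, 19, 16, 16)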


def get_danet(dataset='citys', backbone='resnet50', pretrained=False,
              root='~/.torch/models', pretrained_base=True, **kwargs):
    r"""Dual Attention Network

    Parameters
    ----------
    dataset : str, default 'citys'
        The dataset that the model was pretrained on (pascal_voc, ade20k, citys, ...).
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    pretrained_base : bool or str, default True
        This will load the pretrained backbone network, which was trained on ImageNet.

    Examples
    --------
    >>> model = get_danet(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # fall back to CPU when no 'local_rank' device hint is given
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('danet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_danet_resnet50_citys(**kwargs):
    return get_danet('citys', 'resnet50', **kwargs)


def get_danet_resnet101_citys(**kwargs):
    return get_danet('citys', 'resnet101', **kwargs)


def get_danet_resnet152_citys(**kwargs):
    return get_danet('citys', 'resnet152', **kwargs)
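

# Usage sketch (hypothetical helper, not part of the original file): the
# factories forward everything to get_danet, so pretrained Cityscapes weights
# can be requested directly; map_location defaults to CPU unless 'local_rank'
# is passed through kwargs.
def _load_pretrained_citys_example():
    return get_danet_resnet50_citys(pretrained=True, root='~/.torch/models')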


if __name__ == '__main__':
    # img = torch.randn(2, 3, 480, 480)
    # model = get_danet_resnet50_citys()
    # outputs = model(img)
    img = torch.rand(2, 3, 512, 512)
    model = DANet(4, pretrained_base=False)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    # print(model)
    out = model(img)
    print(out, out.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints each layer's output shape and parameter count in order

    from thop import profile
    flop, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))
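
    # Fallback sketch (not part of the original file): FLOPs need a profiler
    # like thop, but the parameter count can be verified with plain PyTorch.
    n_params = sum(p.numel() for p in model.parameters())
    print('params (direct count): {:.3f}M'.format(n_params / 1e6))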