用kafka接收消息
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

cgnet.py 7.8KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. """Context Guided Network for Semantic Segmentation"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.nn import _ConvBNPReLU, _BNPReLU
  6. __all__ = ['CGNet', 'get_cgnet', 'get_cgnet_citys']
  7. class CGNet(nn.Module):
  8. r"""CGNet
  9. Parameters
  10. ----------
  11. nclass : int
  12. Number of categories for the training dataset.
  13. norm_layer : object
  14. Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
  15. for Synchronized Cross-GPU BachNormalization).
  16. aux : bool
  17. Auxiliary loss.
  18. Reference:
  19. Tianyi Wu, et al. "CGNet: A Light-weight Context Guided Network for Semantic Segmentation."
  20. arXiv preprint arXiv:1811.08201 (2018).
  21. """
  22. def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=True, M=3, N=21, **kwargs):
  23. super(CGNet, self).__init__()
  24. # stage 1
  25. self.stage1_0 = _ConvBNPReLU(3, 32, 3, 2, 1, **kwargs)
  26. self.stage1_1 = _ConvBNPReLU(32, 32, 3, 1, 1, **kwargs)
  27. self.stage1_2 = _ConvBNPReLU(32, 32, 3, 1, 1, **kwargs)
  28. self.sample1 = _InputInjection(1)
  29. self.sample2 = _InputInjection(2)
  30. self.bn_prelu1 = _BNPReLU(32 + 3, **kwargs)
  31. # stage 2
  32. self.stage2_0 = ContextGuidedBlock(32 + 3, 64, dilation=2, reduction=8, down=True, residual=False, **kwargs)
  33. self.stage2 = nn.ModuleList()
  34. for i in range(0, M - 1):
  35. self.stage2.append(ContextGuidedBlock(64, 64, dilation=2, reduction=8, **kwargs))
  36. self.bn_prelu2 = _BNPReLU(128 + 3, **kwargs)
  37. # stage 3
  38. self.stage3_0 = ContextGuidedBlock(128 + 3, 128, dilation=4, reduction=16, down=True, residual=False, **kwargs)
  39. self.stage3 = nn.ModuleList()
  40. for i in range(0, N - 1):
  41. self.stage3.append(ContextGuidedBlock(128, 128, dilation=4, reduction=16, **kwargs))
  42. self.bn_prelu3 = _BNPReLU(256, **kwargs)
  43. self.head = nn.Sequential(
  44. nn.Dropout2d(0.1, False),
  45. nn.Conv2d(256, nclass, 1))
  46. self.__setattr__('exclusive', ['stage1_0', 'stage1_1', 'stage1_2', 'sample1', 'sample2',
  47. 'bn_prelu1', 'stage2_0', 'stage2', 'bn_prelu2', 'stage3_0',
  48. 'stage3', 'bn_prelu3', 'head'])
  49. def forward(self, x):
  50. size = x.size()[2:]
  51. # stage1
  52. out0 = self.stage1_0(x)
  53. out0 = self.stage1_1(out0)
  54. out0 = self.stage1_2(out0)
  55. inp1 = self.sample1(x)
  56. inp2 = self.sample2(x)
  57. # stage 2
  58. out0_cat = self.bn_prelu1(torch.cat([out0, inp1], dim=1))
  59. out1_0 = self.stage2_0(out0_cat)
  60. for i, layer in enumerate(self.stage2):
  61. if i == 0:
  62. out1 = layer(out1_0)
  63. else:
  64. out1 = layer(out1)
  65. out1_cat = self.bn_prelu2(torch.cat([out1, out1_0, inp2], dim=1))
  66. # stage 3
  67. out2_0 = self.stage3_0(out1_cat)
  68. for i, layer in enumerate(self.stage3):
  69. if i == 0:
  70. out2 = layer(out2_0)
  71. else:
  72. out2 = layer(out2)
  73. out2_cat = self.bn_prelu3(torch.cat([out2_0, out2], dim=1))
  74. outputs = []
  75. out = self.head(out2_cat)
  76. out = F.interpolate(out, size, mode='bilinear', align_corners=True)
  77. outputs.append(out)
  78. #return tuple(outputs)
  79. return outputs[0]
  80. class _ChannelWiseConv(nn.Module):
  81. def __init__(self, in_channels, out_channels, dilation=1, **kwargs):
  82. super(_ChannelWiseConv, self).__init__()
  83. self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, dilation, dilation, groups=in_channels, bias=False)
  84. def forward(self, x):
  85. x = self.conv(x)
  86. return x
  87. class _FGlo(nn.Module):
  88. def __init__(self, in_channels, reduction=16, **kwargs):
  89. super(_FGlo, self).__init__()
  90. self.gap = nn.AdaptiveAvgPool2d(1)
  91. self.fc = nn.Sequential(
  92. nn.Linear(in_channels, in_channels // reduction),
  93. nn.ReLU(True),
  94. nn.Linear(in_channels // reduction, in_channels),
  95. nn.Sigmoid())
  96. def forward(self, x):
  97. n, c, _, _ = x.size()
  98. out = self.gap(x).view(n, c)
  99. out = self.fc(out).view(n, c, 1, 1)
  100. return x * out
  101. class _InputInjection(nn.Module):
  102. def __init__(self, ratio):
  103. super(_InputInjection, self).__init__()
  104. self.pool = nn.ModuleList()
  105. for i in range(0, ratio):
  106. self.pool.append(nn.AvgPool2d(3, 2, 1))
  107. def forward(self, x):
  108. for pool in self.pool:
  109. x = pool(x)
  110. return x
  111. class _ConcatInjection(nn.Module):
  112. def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs):
  113. super(_ConcatInjection, self).__init__()
  114. self.bn = norm_layer(in_channels)
  115. self.prelu = nn.PReLU(in_channels)
  116. def forward(self, x1, x2):
  117. out = torch.cat([x1, x2], dim=1)
  118. out = self.bn(out)
  119. out = self.prelu(out)
  120. return out
  121. class ContextGuidedBlock(nn.Module):
  122. def __init__(self, in_channels, out_channels, dilation=2, reduction=16, down=False,
  123. residual=True, norm_layer=nn.BatchNorm2d, **kwargs):
  124. super(ContextGuidedBlock, self).__init__()
  125. inter_channels = out_channels // 2 if not down else out_channels
  126. if down:
  127. self.conv = _ConvBNPReLU(in_channels, inter_channels, 3, 2, 1, norm_layer=norm_layer, **kwargs)
  128. self.reduce = nn.Conv2d(inter_channels * 2, out_channels, 1, bias=False)
  129. else:
  130. self.conv = _ConvBNPReLU(in_channels, inter_channels, 1, 1, 0, norm_layer=norm_layer, **kwargs)
  131. self.f_loc = _ChannelWiseConv(inter_channels, inter_channels, **kwargs)
  132. self.f_sur = _ChannelWiseConv(inter_channels, inter_channels, dilation, **kwargs)
  133. self.bn = norm_layer(inter_channels * 2)
  134. self.prelu = nn.PReLU(inter_channels * 2)
  135. self.f_glo = _FGlo(out_channels, reduction, **kwargs)
  136. self.down = down
  137. self.residual = residual
  138. def forward(self, x):
  139. out = self.conv(x)
  140. loc = self.f_loc(out)
  141. sur = self.f_sur(out)
  142. joi_feat = torch.cat([loc, sur], dim=1)
  143. joi_feat = self.prelu(self.bn(joi_feat))
  144. if self.down:
  145. joi_feat = self.reduce(joi_feat)
  146. out = self.f_glo(joi_feat)
  147. if self.residual:
  148. out = out + x
  149. return out
  150. def get_cgnet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs):
  151. acronyms = {
  152. 'pascal_voc': 'pascal_voc',
  153. 'pascal_aug': 'pascal_aug',
  154. 'ade20k': 'ade',
  155. 'coco': 'coco',
  156. 'citys': 'citys',
  157. }
  158. from core.data.dataloader import datasets
  159. model = CGNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  160. if pretrained:
  161. from .model_store import get_model_file
  162. device = torch.device(kwargs['local_rank'])
  163. model.load_state_dict(torch.load(get_model_file('cgnet_%s' % (acronyms[dataset]), root=root),
  164. map_location=device))
  165. return model
  166. def get_cgnet_citys(**kwargs):
  167. return get_cgnet('citys', '', **kwargs)
  168. if __name__ == '__main__':
  169. # model = get_cgnet_citys()
  170. # print(model)
  171. input = torch.rand(2, 3, 224, 224)
  172. model = CGNet(4, pretrained_base=True)
  173. # target = torch.zeros(4, 512, 512).cuda()
  174. # model.eval()
  175. # print(model)
  176. loss = model(input)
  177. print(loss, loss.shape)
  178. # from torchsummary import summary
  179. #
  180. # summary(model, (3, 224, 224)) # 打印表格,按顺序输出每层的输出形状和参数
  181. import torch
  182. from thop import profile
  183. from torchsummary import summary
  184. flop, params = profile(model, input_size=(1, 3, 512, 512))
  185. print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))