用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

198 lines
7.9KB

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from core.models.base_models.densenet import *
  5. from core.models.fcn import _FCNHead
  6. __all__ = ['DenseASPP', 'get_denseaspp', 'get_denseaspp_densenet121_citys',
  7. 'get_denseaspp_densenet161_citys', 'get_denseaspp_densenet169_citys', 'get_denseaspp_densenet201_citys']
  8. class DenseASPP(nn.Module):
  9. def __init__(self, nclass, backbone='densenet121', aux=False, jpu=False,
  10. pretrained_base=True, dilate_scale=8, **kwargs):
  11. super(DenseASPP, self).__init__()
  12. self.nclass = nclass
  13. self.aux = aux
  14. self.dilate_scale = dilate_scale
  15. if backbone == 'densenet121':
  16. self.pretrained = dilated_densenet121(dilate_scale, pretrained=pretrained_base, **kwargs)
  17. elif backbone == 'densenet161':
  18. self.pretrained = dilated_densenet161(dilate_scale, pretrained=pretrained_base, **kwargs)
  19. elif backbone == 'densenet169':
  20. self.pretrained = dilated_densenet169(dilate_scale, pretrained=pretrained_base, **kwargs)
  21. elif backbone == 'densenet201':
  22. self.pretrained = dilated_densenet201(dilate_scale, pretrained=pretrained_base, **kwargs)
  23. else:
  24. raise RuntimeError('unknown backbone: {}'.format(backbone))
  25. in_channels = self.pretrained.num_features
  26. self.head = _DenseASPPHead(in_channels, nclass)
  27. if aux:
  28. self.auxlayer = _FCNHead(in_channels, nclass, **kwargs)
  29. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  30. def forward(self, x):
  31. size = x.size()[2:]
  32. #print('size', size) #torch.Size([512, 512])
  33. features = self.pretrained.features(x)
  34. #print('22',features.shape) #torch.Size([2, 1024, 64, 64])
  35. if self.dilate_scale > 8:
  36. features = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=True)
  37. outputs = []
  38. x = self.head(features) #torch.Size([2, 4, 64, 64])
  39. #print('x.shape',x.shape)
  40. x = F.interpolate(x, size, mode='bilinear', align_corners=True)#直接64到512。。。。效果还这么好!
  41. outputs.append(x)
  42. if self.aux:
  43. auxout = self.auxlayer(features)
  44. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  45. outputs.append(auxout)
  46. #return tuple(outputs)
  47. return outputs[0]
  48. class _DenseASPPHead(nn.Module):
  49. def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  50. super(_DenseASPPHead, self).__init__()
  51. self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs)
  52. self.block = nn.Sequential(
  53. nn.Dropout(0.1),
  54. nn.Conv2d(in_channels + 5 * 64, nclass, 1)
  55. )
  56. def forward(self, x):
  57. x = self.dense_aspp_block(x)
  58. return self.block(x)
  59. class _DenseASPPConv(nn.Sequential):
  60. def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
  61. drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
  62. super(_DenseASPPConv, self).__init__()
  63. self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1)),
  64. self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs))),
  65. self.add_module('relu1', nn.ReLU(True)),
  66. self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)),
  67. self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs))),
  68. self.add_module('relu2', nn.ReLU(True)),
  69. self.drop_rate = drop_rate
  70. def forward(self, x):
  71. features = super(_DenseASPPConv, self).forward(x)
  72. if self.drop_rate > 0:
  73. features = F.dropout(features, p=self.drop_rate, training=self.training)
  74. return features
  75. class _DenseASPPBlock(nn.Module):
  76. def __init__(self, in_channels, inter_channels1, inter_channels2,
  77. norm_layer=nn.BatchNorm2d, norm_kwargs=None):
  78. super(_DenseASPPBlock, self).__init__()
  79. self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
  80. norm_layer, norm_kwargs)
  81. self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1,
  82. norm_layer, norm_kwargs)
  83. self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1,
  84. norm_layer, norm_kwargs)
  85. self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1,
  86. norm_layer, norm_kwargs)
  87. self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1,
  88. norm_layer, norm_kwargs)
  89. def forward(self, x):
  90. aspp3 = self.aspp_3(x)
  91. x = torch.cat([aspp3, x], dim=1)
  92. aspp6 = self.aspp_6(x)
  93. x = torch.cat([aspp6, x], dim=1)
  94. aspp12 = self.aspp_12(x)
  95. x = torch.cat([aspp12, x], dim=1)
  96. aspp18 = self.aspp_18(x)
  97. x = torch.cat([aspp18, x], dim=1)
  98. aspp24 = self.aspp_24(x)
  99. x = torch.cat([aspp24, x], dim=1)
  100. return x
  101. def get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False,
  102. root='~/.torch/models', pretrained_base=True, **kwargs):
  103. r"""DenseASPP
  104. Parameters
  105. ----------
  106. dataset : str, default citys
  107. The dataset that model pretrained on. (pascal_voc, ade20k)
  108. pretrained : bool or str
  109. Boolean value controls whether to load the default pretrained weights for model.
  110. String value represents the hashtag for a certain version of pretrained weights.
  111. root : str, default '~/.torch/models'
  112. Location for keeping the model parameters.
  113. pretrained_base : bool or str, default True
  114. This will load pretrained backbone network, that was trained on ImageNet.
  115. Examples
  116. --------
  117. # >>> model = get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False)
  118. # >>> print(model)
  119. """
  120. acronyms = {
  121. 'pascal_voc': 'pascal_voc',
  122. 'pascal_aug': 'pascal_aug',
  123. 'ade20k': 'ade',
  124. 'coco': 'coco',
  125. 'citys': 'citys',
  126. }
  127. from ..data.dataloader import datasets
  128. model = DenseASPP(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  129. if pretrained:
  130. from .model_store import get_model_file
  131. device = torch.device(kwargs['local_rank'])
  132. model.load_state_dict(torch.load(get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root),
  133. map_location=device))
  134. return model
  135. def get_denseaspp_densenet121_citys(**kwargs):
  136. return get_denseaspp('citys', 'densenet121', **kwargs)
  137. def get_denseaspp_densenet161_citys(**kwargs):
  138. return get_denseaspp('citys', 'densenet161', **kwargs)
  139. def get_denseaspp_densenet169_citys(**kwargs):
  140. return get_denseaspp('citys', 'densenet169', **kwargs)
  141. def get_denseaspp_densenet201_citys(**kwargs):
  142. return get_denseaspp('citys', 'densenet201', **kwargs)
  143. if __name__ == '__main__':
  144. # img = torch.randn(2, 3, 480, 480)
  145. # model = get_denseaspp_densenet121_citys()
  146. # outputs = model(img)
  147. input = torch.rand(2, 3, 512, 512)
  148. model = DenseASPP(4, pretrained_base=True)
  149. # target = torch.zeros(4, 512, 512).cuda()
  150. # model.eval()
  151. # print(model)
  152. loss = model(input)
  153. print(loss, loss.shape)
  154. # from torchsummary import summary
  155. #
  156. # summary(model, (3, 224, 224)) # 打印表格,按顺序输出每层的输出形状和参数
  157. import torch
  158. from thop import profile
  159. from torchsummary import summary
  160. flop, params = profile(model, input_size=(1, 3, 512, 512))
  161. print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))