用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

185 lines
6.6KB

  1. """Pyramid Scene Parsing Network"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.segbase import SegBaseModel
  6. from core.models.fcn import _FCNHead
  7. __all__ = ['PSPNet', 'get_psp', 'get_psp_resnet50_voc', 'get_psp_resnet50_ade', 'get_psp_resnet101_voc',
  8. 'get_psp_resnet101_ade', 'get_psp_resnet101_citys', 'get_psp_resnet101_coco']
  9. class PSPNet(SegBaseModel):
  10. r"""Pyramid Scene Parsing Network
  11. Parameters
  12. ----------
  13. nclass : int
  14. Number of categories for the training dataset.
  15. backbone : string
  16. Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
  17. 'resnet101' or 'resnet152').
  18. norm_layer : object
  19. Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
  20. for Synchronized Cross-GPU BachNormalization).
  21. aux : bool
  22. Auxiliary loss.
  23. Reference:
  24. Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia.
  25. "Pyramid scene parsing network." *CVPR*, 2017
  26. """
  27. def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
  28. super(PSPNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  29. self.head = _PSPHead(nclass, **kwargs)
  30. if self.aux:
  31. self.auxlayer = _FCNHead(1024, nclass, **kwargs)
  32. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  33. def forward(self, x):
  34. size = x.size()[2:]
  35. _, _, c3, c4 = self.base_forward(x)
  36. outputs = []
  37. x = self.head(c4)
  38. x = F.interpolate(x, size, mode='bilinear', align_corners=True)
  39. outputs.append(x)
  40. if self.aux:
  41. auxout = self.auxlayer(c3)
  42. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  43. outputs.append(auxout)
  44. #return tuple(outputs)
  45. return outputs[0]
  46. def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs):
  47. return nn.Sequential(
  48. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  49. norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
  50. nn.ReLU(True)
  51. )
  52. class _PyramidPooling(nn.Module):
  53. def __init__(self, in_channels, **kwargs):
  54. super(_PyramidPooling, self).__init__()
  55. out_channels = int(in_channels / 4)
  56. self.avgpool1 = nn.AdaptiveAvgPool2d(1)
  57. self.avgpool2 = nn.AdaptiveAvgPool2d(2)
  58. self.avgpool3 = nn.AdaptiveAvgPool2d(3)
  59. self.avgpool4 = nn.AdaptiveAvgPool2d(6)
  60. self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
  61. self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
  62. self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
  63. self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
  64. def forward(self, x):
  65. size = x.size()[2:]
  66. feat1 = F.interpolate(self.conv1(self.avgpool1(x)), size, mode='bilinear', align_corners=True)
  67. feat2 = F.interpolate(self.conv2(self.avgpool2(x)), size, mode='bilinear', align_corners=True)
  68. feat3 = F.interpolate(self.conv3(self.avgpool3(x)), size, mode='bilinear', align_corners=True)
  69. feat4 = F.interpolate(self.conv4(self.avgpool4(x)), size, mode='bilinear', align_corners=True)
  70. return torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
  71. class _PSPHead(nn.Module):
  72. def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  73. super(_PSPHead, self).__init__()
  74. self.psp = _PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
  75. self.block = nn.Sequential(
  76. nn.Conv2d(4096, 512, 3, padding=1, bias=False),
  77. norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
  78. nn.ReLU(True),
  79. nn.Dropout(0.1),
  80. nn.Conv2d(512, nclass, 1)
  81. )
  82. def forward(self, x):
  83. x = self.psp(x)
  84. return self.block(x)
  85. def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
  86. pretrained_base=True, **kwargs):
  87. r"""Pyramid Scene Parsing Network
  88. Parameters
  89. ----------
  90. dataset : str, default pascal_voc
  91. The dataset that model pretrained on. (pascal_voc, ade20k)
  92. pretrained : bool or str
  93. Boolean value controls whether to load the default pretrained weights for model.
  94. String value represents the hashtag for a certain version of pretrained weights.
  95. root : str, default '~/.torch/models'
  96. Location for keeping the model parameters.
  97. pretrained_base : bool or str, default True
  98. This will load pretrained backbone network, that was trained on ImageNet.
  99. Examples
  100. --------
  101. >>> model = get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False)
  102. >>> print(model)
  103. """
  104. acronyms = {
  105. 'pascal_voc': 'pascal_voc',
  106. 'pascal_aug': 'pascal_aug',
  107. 'ade20k': 'ade',
  108. 'coco': 'coco',
  109. 'citys': 'citys',
  110. }
  111. from ..data.dataloader import datasets
  112. model = PSPNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  113. if pretrained:
  114. from .model_store import get_model_file
  115. device = torch.device(kwargs['local_rank'])
  116. model.load_state_dict(torch.load(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root),
  117. map_location=device))
  118. return model
  119. def get_psp_resnet50_voc(**kwargs):
  120. return get_psp('pascal_voc', 'resnet50', **kwargs)
  121. def get_psp_resnet50_ade(**kwargs):
  122. return get_psp('ade20k', 'resnet50', **kwargs)
  123. def get_psp_resnet101_voc(**kwargs):
  124. return get_psp('pascal_voc', 'resnet101', **kwargs)
  125. def get_psp_resnet101_ade(**kwargs):
  126. return get_psp('ade20k', 'resnet101', **kwargs)
  127. def get_psp_resnet101_citys(**kwargs):
  128. return get_psp('citys', 'resnet101', **kwargs)
  129. def get_psp_resnet101_coco(**kwargs):
  130. return get_psp('coco', 'resnet101', **kwargs)
  131. if __name__ == '__main__':
  132. # model = get_psp_resnet50_voc()
  133. # img = torch.randn(4, 3, 480, 480)
  134. # output = model(img)
  135. input = torch.rand(2, 3, 512, 512)
  136. model = PSPNet(4, pretrained_base=False)
  137. # target = torch.zeros(4, 512, 512).cuda()
  138. # model.eval()
  139. # print(model)
  140. loss = model(input)
  141. print(loss, loss.shape)
  142. # from torchsummary import summary
  143. #
  144. # summary(model, (3, 224, 224)) # 打印表格,按顺序输出每层的输出形状和参数
  145. import torch
  146. from thop import profile
  147. from torchsummary import summary
  148. flop, params = profile(model, input_size=(1, 3, 512, 512))
  149. print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))