用kafka接收消息
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
6.4KB

"""DeepLabV3 semantic segmentation network"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from .segbase import SegBaseModel
  6. from .fcn import _FCNHead
  7. __all__ = ['DeepLabV3', 'get_deeplabv3', 'get_deeplabv3_resnet50_voc', 'get_deeplabv3_resnet101_voc',
  8. 'get_deeplabv3_resnet152_voc', 'get_deeplabv3_resnet50_ade', 'get_deeplabv3_resnet101_ade',
  9. 'get_deeplabv3_resnet152_ade']
  10. class DeepLabV3(SegBaseModel):
  11. r"""DeepLabV3
  12. Parameters
  13. ----------
  14. nclass : int
  15. Number of categories for the training dataset.
  16. backbone : string
  17. Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
  18. 'resnet101' or 'resnet152').
  19. norm_layer : object
  20. Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
  21. for Synchronized Cross-GPU BachNormalization).
  22. aux : bool
  23. Auxiliary loss.
  24. Reference:
  25. Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation."
  26. arXiv preprint arXiv:1706.05587 (2017).
  27. """
  28. def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
  29. super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  30. self.head = _DeepLabHead(nclass, **kwargs)
  31. if self.aux:
  32. self.auxlayer = _FCNHead(1024, nclass, **kwargs)
  33. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  34. def forward(self, x):
  35. size = x.size()[2:]
  36. _, _, c3, c4 = self.base_forward(x)
  37. outputs = []
  38. x = self.head(c4)
  39. x = F.interpolate(x, size, mode='bilinear', align_corners=True)
  40. outputs.append(x)
  41. if self.aux:
  42. auxout = self.auxlayer(c3)
  43. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  44. outputs.append(auxout)
  45. return tuple(outputs)
  46. class _DeepLabHead(nn.Module):
  47. def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  48. super(_DeepLabHead, self).__init__()
  49. self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
  50. self.block = nn.Sequential(
  51. nn.Conv2d(256, 256, 3, padding=1, bias=False),
  52. norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)),
  53. nn.ReLU(True),
  54. nn.Dropout(0.1),
  55. nn.Conv2d(256, nclass, 1)
  56. )
  57. def forward(self, x):
  58. x = self.aspp(x)
  59. return self.block(x)
  60. class _ASPPConv(nn.Module):
  61. def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
  62. super(_ASPPConv, self).__init__()
  63. self.block = nn.Sequential(
  64. nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
  65. norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
  66. nn.ReLU(True)
  67. )
  68. def forward(self, x):
  69. return self.block(x)
  70. class _AsppPooling(nn.Module):
  71. def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, **kwargs):
  72. super(_AsppPooling, self).__init__()
  73. self.gap = nn.Sequential(
  74. nn.AdaptiveAvgPool2d(1),
  75. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  76. norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
  77. nn.ReLU(True)
  78. )
  79. def forward(self, x):
  80. size = x.size()[2:]
  81. pool = self.gap(x)
  82. out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
  83. return out
  84. class _ASPP(nn.Module):
  85. def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs):
  86. super(_ASPP, self).__init__()
  87. out_channels = 256
  88. self.b0 = nn.Sequential(
  89. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  90. norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
  91. nn.ReLU(True)
  92. )
  93. rate1, rate2, rate3 = tuple(atrous_rates)
  94. self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
  95. self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
  96. self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
  97. self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
  98. self.project = nn.Sequential(
  99. nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
  100. norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
  101. nn.ReLU(True),
  102. nn.Dropout(0.5)
  103. )
  104. def forward(self, x):
  105. feat1 = self.b0(x)
  106. feat2 = self.b1(x)
  107. feat3 = self.b2(x)
  108. feat4 = self.b3(x)
  109. feat5 = self.b4(x)
  110. x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
  111. x = self.project(x)
  112. return x
  113. def get_deeplabv3(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
  114. pretrained_base=True, **kwargs):
  115. acronyms = {
  116. 'pascal_voc': 'pascal_voc',
  117. 'pascal_aug': 'pascal_aug',
  118. 'ade20k': 'ade',
  119. 'coco': 'coco',
  120. 'citys': 'citys',
  121. }
  122. from ..data.dataloader import datasets
  123. model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  124. if pretrained:
  125. from .model_store import get_model_file
  126. device = torch.device(kwargs['local_rank'])
  127. model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
  128. map_location=device))
  129. return model
  130. def get_deeplabv3_resnet50_voc(**kwargs):
  131. return get_deeplabv3('pascal_voc', 'resnet50', **kwargs)
  132. def get_deeplabv3_resnet101_voc(**kwargs):
  133. return get_deeplabv3('pascal_voc', 'resnet101', **kwargs)
  134. def get_deeplabv3_resnet152_voc(**kwargs):
  135. return get_deeplabv3('pascal_voc', 'resnet152', **kwargs)
  136. def get_deeplabv3_resnet50_ade(**kwargs):
  137. return get_deeplabv3('ade20k', 'resnet50', **kwargs)
  138. def get_deeplabv3_resnet101_ade(**kwargs):
  139. return get_deeplabv3('ade20k', 'resnet101', **kwargs)
  140. def get_deeplabv3_resnet152_ade(**kwargs):
  141. return get_deeplabv3('ade20k', 'resnet152', **kwargs)
  142. if __name__ == '__main__':
  143. model = get_deeplabv3_resnet50_voc()
  144. img = torch.randn(2, 3, 480, 480)
  145. output = model(img)