  1. """Decoders Matter for Semantic Segmentation"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.segbase import SegBaseModel
  6. from core.models.fcn import _FCNHead
  7. __all__ = ['DUNet', 'get_dunet', 'get_dunet_resnet50_pascal_voc',
  8. 'get_dunet_resnet101_pascal_voc', 'get_dunet_resnet152_pascal_voc']
  9. # The model may be wrong because lots of details missing in paper.
  10. class DUNet(SegBaseModel):
  11. """Decoders Matter for Semantic Segmentation
  12. Reference:
  13. Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan.
  14. "Decoders Matter for Semantic Segmentation:
  15. Data-Dependent Decoding Enables Flexible Feature Aggregation." CVPR, 2019
  16. """
  17. def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs):
  18. super(DUNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  19. self.head = _DUHead(2144, **kwargs)
  20. self.dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs)
  21. if aux:
  22. self.auxlayer = _FCNHead(1024, 256, **kwargs)
  23. self.aux_dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs)
  24. self.__setattr__('exclusive',
  25. ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if aux else ['dupsample', 'head'])
  26. def forward(self, x):
  27. c1, c2, c3, c4 = self.base_forward(x)#继承自SegBaseModel;返回的是resnet的layer1,2,3,4的输出
  28. outputs = []
  29. x = self.head(c2, c3, c4)
  30. x = self.dupsample(x)
  31. outputs.append(x)
  32. if self.aux:
  33. auxout = self.auxlayer(c3)
  34. auxout = self.aux_dupsample(auxout)
  35. outputs.append(auxout)
  36. #return tuple(outputs)
  37. return outputs[0]
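
# Output-shape note (an assumption based on the usual SegBaseModel setup, where the
# dilated ResNet backbone produces stride-8 features): with scale_factor=8 in
# DUpsampling, the returned logits have shape (N, nclass, H, W) for an (N, 3, H, W)
# input, so no extra F.interpolate is needed at the end of forward().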
class FeatureFused(nn.Module):
    """Module for fused features"""

    def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d, **kwargs):
        super(FeatureFused, self).__init__()
        self.conv2 = nn.Sequential(
            nn.Conv2d(512, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.ReLU(True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(1024, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.ReLU(True)
        )

    def forward(self, c2, c3, c4):
        size = c4.size()[2:]
        c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True))
        c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True))
        fused_feature = torch.cat([c4, c3, c2], dim=1)
        return fused_feature
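
# Note on the fused channel count (an inferred explanation, not stated in the file):
# with a ResNet-50/101/152 backbone, c4 has 2048 channels, and conv2/conv3 each
# project c2/c3 down to inter_channels=48, so the concatenation has
# 2048 + 48 + 48 = 2144 channels, which is why DUNet constructs _DUHead(2144).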
class _DUHead(nn.Module):
    def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_DUHead, self).__init__()
        self.fuse = FeatureFused(norm_layer=norm_layer, **kwargs)
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1, bias=False),
            norm_layer(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            norm_layer(256),
            nn.ReLU(True)
        )

    def forward(self, c2, c3, c4):
        fused_feature = self.fuse(c2, c3, c4)
        out = self.block(fused_feature)
        return out


class DUpsampling(nn.Module):
    """DUpsampling module"""

    def __init__(self, in_channels, out_channels, scale_factor=2, **kwargs):
        super(DUpsampling, self).__init__()
        self.scale_factor = scale_factor
        self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False)

    def forward(self, x):
        x = self.conv_w(x)
        n, c, h, w = x.size()
        # N, C, H, W --> N, W, H, C
        x = x.permute(0, 3, 2, 1).contiguous()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, h * self.scale_factor, c // self.scale_factor)
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * self.scale_factor))
        # N, H * scale, W * scale, C // (scale ** 2) --> N, C // (scale ** 2), H * scale, W * scale
        x = x.permute(0, 3, 1, 2)
        return x
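
# Shape sketch for DUpsampling (illustrative, using the values DUNet passes in:
# in_channels=256, out_channels=nclass, scale_factor=8). For a stride-8 feature
# map of size (N, 256, H/8, W/8):
#   conv_w:          (N, 256, H/8, W/8)        -> (N, nclass * 64, H/8, W/8)
#   permute/view:    (N, nclass * 64, H/8, W/8) -> (N, nclass, H, W)
# i.e. the 1x1 convolution holds the learned upsampling weights and the
# permute/view calls rearrange its output into a full-resolution prediction.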
def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
              root='~/.torch/models', pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Loading pretrained weights requires a 'local_rank' entry in kwargs to select the target device.
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_dunet_resnet50_pascal_voc(**kwargs):
    return get_dunet('pascal_voc', 'resnet50', **kwargs)


def get_dunet_resnet101_pascal_voc(**kwargs):
    return get_dunet('pascal_voc', 'resnet101', **kwargs)


def get_dunet_resnet152_pascal_voc(**kwargs):
    return get_dunet('pascal_voc', 'resnet152', **kwargs)
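
# Example usage (a minimal sketch; assumes the core.data.dataloader datasets
# registry is importable so NUM_CLASS resolves for 'pascal_voc'):
#   model = get_dunet_resnet50_pascal_voc(pretrained=False, pretrained_base=False)
#   logits = model(torch.randn(1, 3, 480, 480))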
if __name__ == '__main__':
    # img = torch.randn(2, 3, 256, 256)
    # model = get_dunet_resnet50_pascal_voc()
    # outputs = model(img)
    img = torch.rand(2, 3, 224, 224)
    model = DUNet(4, pretrained_base=False)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    # print(model)
    out = model(img)
    print(out, out.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints a table with each layer's output shape and parameter count

    from thop import profile
    img = torch.randn(1, 3, 512, 512)
    flop, params = profile(model, inputs=(img,))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))