  1. """ Deep Feature Aggregation"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.base_models import Enc, FCAttention, get_xception_a
  6. from core.nn import _ConvBNReLU
  7. __all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys']
  8. class DFANet(nn.Module):
  9. def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs):
  10. super(DFANet, self).__init__()
  11. self.pretrained = get_xception_a(pretrained_base, **kwargs)
  12. self.enc2_2 = Enc(240, 48, 4, **kwargs)
  13. self.enc3_2 = Enc(144, 96, 6, **kwargs)
  14. self.enc4_2 = Enc(288, 192, 4, **kwargs)
  15. self.fca_2 = FCAttention(192, **kwargs)
  16. self.enc2_3 = Enc(240, 48, 4, **kwargs)
  17. self.enc3_3 = Enc(144, 96, 6, **kwargs)
  18. self.enc3_4 = Enc(288, 192, 4, **kwargs)
  19. self.fca_3 = FCAttention(192, **kwargs)
  20. self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
  21. self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
  22. self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
  23. self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs)
  24. self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
  25. self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
  26. self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
  27. self.conv_out = nn.Conv2d(32, nclass, 1)
  28. self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3',
  29. 'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce',
  30. 'fca_2_reduce', 'fca_3_reduce', 'conv_out'])
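
    # How the three stages fit together: stage 1 is the plain XceptionA backbone;
    # stages 2 and 3 re-run enc2/enc3/enc4/fca on their own features concatenated
    # with the corresponding features of the previous stage (which is why the
    # in_channels above are 240/144/288 rather than the backbone widths 48/96/192),
    # and the decoder sums the reduced enc2-level and fca-level features of all
    # three stages before the final 1x1 classifier.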
    def forward(self, x):
        # stage 1: backbone
        stage1_conv1 = self.pretrained.conv1(x)
        stage1_enc2 = self.pretrained.enc2(stage1_conv1)
        stage1_enc3 = self.pretrained.enc3(stage1_enc2)
        stage1_enc4 = self.pretrained.enc4(stage1_enc3)
        stage1_fca = self.pretrained.fca(stage1_enc4)
        stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True)

        # stage 2
        stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1))
        stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1))
        stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1))
        stage2_fca = self.fca_2(stage2_enc4)
        stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True)

        # stage 3
        stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1))
        stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1))
        stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1))
        stage3_fca = self.fca_3(stage3_enc4)

        # decoder: fuse the enc2-level features of all three stages at stage-1 enc2 resolution
        stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2)
        stage2_enc2_decoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2,
                                            mode='bilinear', align_corners=True)
        stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4,
                                            mode='bilinear', align_corners=True)
        fusion = stage1_enc2_decoder + stage2_enc2_decoder + stage3_enc2_decoder
        fusion1 = self.conv_fusion(fusion)

        # decoder: add the reduced fca-level features of all three stages, upsampled to the same resolution
        stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4,
                                           mode='bilinear', align_corners=True)
        stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8,
                                           mode='bilinear', align_corners=True)
        stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16,
                                           mode='bilinear', align_corners=True)
        fusion2 = fusion1 + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder

        outputs = list()
        out = self.conv_out(fusion2)
        out1 = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True)
        outputs.append(out1)
        # return tuple(outputs)
        return outputs[0]


def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_dfanet_citys(**kwargs):
    return get_dfanet('citys', **kwargs)


if __name__ == '__main__':
    # quick sanity check with a random input and an untrained backbone
    # model = get_dfanet_citys()
    x = torch.rand(2, 3, 512, 512)
    model = DFANet(4, pretrained_base=False)
    # model.eval()
    out = model(x)
    print(out, out.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints each layer's output shape and parameter count, in order

    # count FLOPs and parameters; newer thop versions take inputs=(tensor,) instead of input_size
    from thop import profile
    flops, params = profile(model, inputs=(torch.rand(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
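
For reference, here is a minimal inference sketch for the pretrained Cityscapes variant. It is not part of the original file: it assumes this module lives in its original package so the relative imports inside get_dfanet resolve, that a 'dfanet_citys' checkpoint has been downloaded under root, and that the import path below matches your layout (it is hypothetical); local_rank is whatever torch.device() should use for map_location.

    # hypothetical usage sketch -- adjust the import path to your package layout
    import torch
    from core.models.dfanet import get_dfanet_citys  # assumed module path

    model = get_dfanet_citys(pretrained=True, local_rank='cpu')  # local_rank feeds torch.device() for map_location
    model.eval()
    with torch.no_grad():
        image = torch.rand(1, 3, 1024, 2048)  # Cityscapes-sized dummy input
        logits = model(image)                 # (1, NUM_CLASS, H, W) per-pixel class scores
        pred = logits.argmax(dim=1)           # predicted class index per pixel
    print(pred.shape)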