|
- """ Deep Feature Aggregation"""
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from core.models.base_models import Enc, FCAttention, get_xception_a
- from core.nn import _ConvBNReLU
-
- __all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys']
-
-
class DFANet(nn.Module):
    """Deep Feature Aggregation network (DFANet) for semantic segmentation.

    Three cascaded encoder passes share a lightweight Xception-style backbone:
    pass 1 is the pretrained backbone itself, passes 2 and 3 re-encode the
    previous pass's features concatenated with its upsampled attention output.
    A two-branch decoder fuses the ``enc2`` features (fine) and the ``fca``
    attention outputs (coarse) of all three passes into the final prediction.

    Parameters
    ----------
    nclass : int
        Number of output segmentation classes.
    backbone : str
        Unused; kept for signature compatibility with the other models.
    aux, jpu : bool
        Unused; kept for signature compatibility with the other models.
    pretrained_base : bool
        Forwarded to ``get_xception_a`` to load ImageNet backbone weights.
    """

    def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs):
        super(DFANet, self).__init__()
        # Pass 1: backbone providing conv1 / enc2 / enc3 / enc4 / fca stages.
        self.pretrained = get_xception_a(pretrained_base, **kwargs)

        # Pass 2 encoder: inputs are pass-1 features concatenated with the
        # preceding stage's output, hence the enlarged input channel counts.
        self.enc2_2 = Enc(240, 48, 4, **kwargs)
        self.enc3_2 = Enc(144, 96, 6, **kwargs)
        self.enc4_2 = Enc(288, 192, 4, **kwargs)
        self.fca_2 = FCAttention(192, **kwargs)

        # Pass 3 encoder (attribute names kept as-is, including the historical
        # ``enc3_4`` spelling, so existing checkpoints keep loading).
        self.enc2_3 = Enc(240, 48, 4, **kwargs)
        self.enc3_3 = Enc(144, 96, 6, **kwargs)
        self.enc3_4 = Enc(288, 192, 4, **kwargs)
        self.fca_3 = FCAttention(192, **kwargs)

        # Decoder, fine branch: project each pass's enc2 features to 32 channels.
        self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs)

        # Decoder, coarse branch: project each pass's fca output to 32 channels.
        self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.conv_out = nn.Conv2d(32, nclass, 1)

        # Modules trained from scratch (everything except the backbone);
        # trainers use this list to apply a larger learning rate.
        self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3',
                                       'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce',
                                       'fca_2_reduce', 'fca_3_reduce', 'conv_out'])

    def forward(self, x):
        """Return the segmentation logits, upsampled 4x from the fused decoder map."""

        def up(feat, factor):
            # Shared bilinear upsampling used throughout the aggregation.
            return F.interpolate(feat, scale_factor=factor, mode='bilinear', align_corners=True)

        # Pass 1: plain backbone forward.
        s1_enc2 = self.pretrained.enc2(self.pretrained.conv1(x))
        s1_enc3 = self.pretrained.enc3(s1_enc2)
        s1_enc4 = self.pretrained.enc4(s1_enc3)
        s1_fca = self.pretrained.fca(s1_enc4)

        # Pass 2: each stage sees the matching pass-1 feature map concatenated
        # with the previous stage's output (the fca output is upsampled 4x
        # to match the enc2 resolution).
        s2_enc2 = self.enc2_2(torch.cat([s1_enc2, up(s1_fca, 4)], dim=1))
        s2_enc3 = self.enc3_2(torch.cat([s1_enc3, s2_enc2], dim=1))
        s2_enc4 = self.enc4_2(torch.cat([s1_enc4, s2_enc3], dim=1))
        s2_fca = self.fca_2(s2_enc4)

        # Pass 3: same pattern, fed from pass-2 features.
        s3_enc2 = self.enc2_3(torch.cat([s2_enc2, up(s2_fca, 4)], dim=1))
        s3_enc3 = self.enc3_3(torch.cat([s2_enc3, s3_enc2], dim=1))
        s3_enc4 = self.enc3_4(torch.cat([s2_enc4, s3_enc3], dim=1))
        s3_fca = self.fca_3(s3_enc4)

        # Decoder, fine branch: align all enc2 maps to the pass-1 resolution
        # (pass 2 is 2x smaller, pass 3 is 4x smaller) and sum.
        fine = self.enc2_1_reduce(s1_enc2)
        fine = fine + up(self.enc2_2_reduce(s2_enc2), 2)
        fine = fine + up(self.enc2_3_reduce(s3_enc2), 4)
        fine = self.conv_fusion(fine)

        # Decoder, coarse branch: align all fca maps to the same resolution
        # (pass 1 is 4x smaller, pass 2 is 8x, pass 3 is 16x) and sum.
        fused = fine + up(self.fca_1_reduce(s1_fca), 4)
        fused = fused + up(self.fca_2_reduce(s2_fca), 8)
        fused = fused + up(self.fca_3_reduce(s3_fca), 16)

        # Classify and restore 4x toward the input resolution.
        return up(self.conv_out(fused), 4)
-
def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Construct a DFANet for the given dataset, optionally loading weights.

    Parameters
    ----------
    dataset : str
        Dataset key ('citys', 'pascal_voc', 'pascal_aug', 'ade20k', 'coco');
        used to look up the number of classes and the weight-file acronym.
    backbone : str
        Unused; kept for signature consistency with the other model factories.
    pretrained : bool
        If True, load dataset-specific pretrained weights from ``root``.
    root : str
        Directory holding the pretrained weight files.
    pretrained_base : bool
        If True, initialise the backbone from ImageNet-pretrained weights.

    Returns
    -------
    DFANet
        The constructed (and possibly weight-loaded) model.
    """
    # Maps dataset keys to the short names used in pretrained weight filenames.
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Fix: original read kwargs['local_rank'] unconditionally, raising
        # KeyError for single-process callers; fall back to CPU mapping.
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root),
                                         map_location=device))
    return model
-
-
def get_dfanet_citys(**kwargs):
    """Shortcut factory: DFANet configured for the Cityscapes dataset."""
    return get_dfanet(dataset='citys', **kwargs)
-
-
if __name__ == '__main__':
    # Smoke test: run an untrained 4-class DFANet on a random batch.
    # ('x'/'out' replace the original 'input'/'loss', which shadowed a
    # builtin and misnamed a forward output.)
    x = torch.rand(2, 3, 512, 512)
    model = DFANet(4, pretrained_base=False)
    out = model(x)
    print(out, out.shape)

    # Profile FLOPs and parameter count with thop. Fix: thop.profile takes
    # the example inputs as a tuple via ``inputs=``; the original passed a
    # nonexistent ``input_size=`` keyword, which raises TypeError.
    from thop import profile

    flops, params = profile(model, inputs=(torch.rand(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
|