"""Deep Feature Aggregation (DFANet) for real-time semantic segmentation."""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from core.models.base_models import Enc, FCAttention, get_xception_a
|
|
from core.nn import _ConvBNReLU
|
|
|
|
__all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys']
|
|
|
|
|
|
class DFANet(nn.Module):
    """Deep Feature Aggregation network (DFANet).

    Three cascaded encoder stages share a lightweight Xception-A backbone.
    Each later stage re-encodes the previous stage's features concatenated
    with an upsampled deeper feature, and a simple decoder fuses the
    32-channel projections of all stages into per-pixel class logits.

    Parameters
    ----------
    nclass : int
        Number of output classes (channels of the final 1x1 conv).
    backbone : str
        Unused placeholder kept for API symmetry with other model factories.
    aux, jpu : bool
        Unused placeholders kept for API symmetry.
    pretrained_base : bool
        If True, load pretrained weights into the Xception backbone.
    """

    def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs):
        super(DFANet, self).__init__()
        # Stage 1: modified Xception backbone exposing conv1/enc2/enc3/enc4/fca.
        self.pretrained = get_xception_a(pretrained_base, **kwargs)

        # Stage 2 encoders.  Input channels are enlarged because each takes the
        # previous stage's feature concatenated with an upsampled deeper one
        # (e.g. 240 = 48 from enc2 + 192 from the upsampled fca output).
        self.enc2_2 = Enc(240, 48, 4, **kwargs)
        self.enc3_2 = Enc(144, 96, 6, **kwargs)
        self.enc4_2 = Enc(288, 192, 4, **kwargs)
        self.fca_2 = FCAttention(192, **kwargs)

        # Stage 3 encoders.  NOTE: the attribute name 'enc3_4' is kept for
        # checkpoint compatibility even though it plays the enc4 role.
        self.enc2_3 = Enc(240, 48, 4, **kwargs)
        self.enc3_3 = Enc(144, 96, 6, **kwargs)
        self.enc3_4 = Enc(288, 192, 4, **kwargs)
        self.fca_3 = FCAttention(192, **kwargs)

        # Decoder: 1x1 projections of each stage's enc2 feature to 32 channels,
        # fused element-wise after upsampling to a common resolution.
        self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
        self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs)

        # 1x1 projections of each stage's attention (fca) output.
        self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
        self.conv_out = nn.Conv2d(32, nclass, 1)

        # Modules trained from scratch (everything but the backbone); training
        # scripts use this list to give them a separate learning rate.
        self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3',
                                       'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce',
                                       'fca_2_reduce', 'fca_3_reduce', 'conv_out'])

    def forward(self, x):
        """Return class logits upsampled x4 relative to the enc2 resolution."""
        # Stage 1: backbone pass.
        stage1_conv1 = self.pretrained.conv1(x)
        stage1_enc2 = self.pretrained.enc2(stage1_conv1)
        stage1_enc3 = self.pretrained.enc3(stage1_enc2)
        stage1_enc4 = self.pretrained.enc4(stage1_enc3)
        stage1_fca = self.pretrained.fca(stage1_enc4)
        # Upsample x4 so the attended feature matches enc2's spatial size.
        stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True)

        # Stage 2: re-encode stage-1 features, each enc fed the concatenation
        # of the same-level stage-1 feature and the previous deeper output.
        stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1))
        stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1))
        stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1))
        stage2_fca = self.fca_2(stage2_enc4)
        stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True)

        # Stage 3: one more refinement pass over stage-2 features.
        stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1))
        stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1))
        stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1))
        stage3_fca = self.fca_3(stage3_enc4)

        # Decoder: fuse the three enc2 features at stage-1 enc2 resolution.
        # Stage 2/3 run at 1/2 and 1/4 of that resolution, hence the x2/x4.
        stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2)
        stage2_enc2_decoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2,
                                            mode='bilinear', align_corners=True)
        stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4,
                                            mode='bilinear', align_corners=True)
        fusion = stage1_enc2_decoder + stage2_enc2_decoder + stage3_enc2_decoder
        fusion1 = self.conv_fusion(fusion)

        # Add the attention features of all stages, upsampled to the same size.
        stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4,
                                           mode='bilinear', align_corners=True)
        stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8,
                                           mode='bilinear', align_corners=True)
        stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16,
                                           mode='bilinear', align_corners=True)
        fusion2 = fusion1 + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder

        # Final classification and x4 upsampling back toward input resolution.
        out = self.conv_out(fusion2)
        return F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True)
|
|
|
|
def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Create a DFANet configured for *dataset*.

    Parameters
    ----------
    dataset : str
        Key into ``core.data.dataloader.datasets`` (e.g. 'citys', 'coco').
    backbone : str
        Unused placeholder kept for API symmetry with other model factories.
    pretrained : bool
        If True, load dataset-specific pretrained weights from *root*.
    root : str
        Directory holding downloaded model checkpoints.
    pretrained_base : bool
        If True, initialize the Xception backbone with pretrained weights.

    Returns
    -------
    DFANet
        The constructed (and optionally weight-loaded) model.
    """
    # Maps dataset names to the short codes used in checkpoint file names.
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Default to device 0 when the caller did not pass 'local_rank'
        # (indexing kwargs['local_rank'] directly raised KeyError).
        device = torch.device(kwargs.get('local_rank', 0))
        model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root),
                                         map_location=device))
    return model
|
|
|
|
|
|
def get_dfanet_citys(**kwargs):
    """DFANet preset for the Cityscapes ('citys') dataset."""
    return get_dfanet(dataset='citys', **kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: forward a random batch through an untrained 4-class model.
    img = torch.rand(2, 3, 512, 512)
    model = DFANet(4, pretrained_base=False)
    out = model(img)
    print(out, out.shape)

    # Rough FLOPs/params estimate with thop.
    # NOTE(review): recent thop versions expect ``inputs=(tensor,)`` rather
    # than ``input_size=...``; keep thop pinned or update this call.
    from thop import profile
    flop, params = profile(model, input_size=(1, 3, 512, 512))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))
|