"""DenseASPP for semantic segmentation."""
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.models.base_models.densenet import *
from core.models.fcn import _FCNHead

__all__ = ['DenseASPP', 'get_denseaspp', 'get_denseaspp_densenet121_citys',
           'get_denseaspp_densenet161_citys', 'get_denseaspp_densenet169_citys',
           'get_denseaspp_densenet201_citys']


class DenseASPP(nn.Module):
    def __init__(self, nclass, backbone='densenet121', aux=False, jpu=False,
                 pretrained_base=True, dilate_scale=8, **kwargs):
        super(DenseASPP, self).__init__()
        self.nclass = nclass
        self.aux = aux
        self.dilate_scale = dilate_scale
        if backbone == 'densenet121':
            self.pretrained = dilated_densenet121(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet161':
            self.pretrained = dilated_densenet161(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet169':
            self.pretrained = dilated_densenet169(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet201':
            self.pretrained = dilated_densenet201(dilate_scale, pretrained=pretrained_base, **kwargs)
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        in_channels = self.pretrained.num_features

        self.head = _DenseASPPHead(in_channels, nclass)

        if aux:
            self.auxlayer = _FCNHead(in_channels, nclass, **kwargs)

        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        size = x.size()[2:]  # input spatial size, e.g. (512, 512)
        features = self.pretrained.features(x)  # e.g. [N, 1024, 64, 64] for densenet121 with dilate_scale=8
        if self.dilate_scale > 8:
            features = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=True)
        outputs = []
        x = self.head(features)  # [N, nclass, 64, 64]
        # Upsample directly from the 1/8-resolution logits (64x64) back to the input size (512x512);
        # this single bilinear step works surprisingly well.
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)

        if self.aux:
            auxout = self.auxlayer(features)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        # return tuple(outputs)
        return outputs[0]  # only the main segmentation output is returned


class _DenseASPPHead(nn.Module):
    def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DenseASPPHead, self).__init__()
        self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs)
        self.block = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(in_channels + 5 * 64, nclass, 1)
        )

    def forward(self, x):
        x = self.dense_aspp_block(x)
        return self.block(x)


class _DenseASPPConv(nn.Sequential):
    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1))
        self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu1', nn.ReLU(True))
        self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate))
        self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.add_module('relu2', nn.ReLU(True))
        self.drop_rate = drop_rate

    def forward(self, x):
        features = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate > 0:
            features = F.dropout(features, p=self.drop_rate, training=self.training)
        return features


class _DenseASPPBlock(nn.Module):
    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        self.aspp_3 = _DenseASPPConv(in_channels, inter_channels1, inter_channels2, 3, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_6 = _DenseASPPConv(in_channels + inter_channels2 * 1, inter_channels1, inter_channels2, 6, 0.1,
                                     norm_layer, norm_kwargs)
        self.aspp_12 = _DenseASPPConv(in_channels + inter_channels2 * 2, inter_channels1, inter_channels2, 12, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_18 = _DenseASPPConv(in_channels + inter_channels2 * 3, inter_channels1, inter_channels2, 18, 0.1,
                                      norm_layer, norm_kwargs)
        self.aspp_24 = _DenseASPPConv(in_channels + inter_channels2 * 4, inter_channels1, inter_channels2, 24, 0.1,
                                      norm_layer, norm_kwargs)

    def forward(self, x):
        # Densely connected atrous convolutions: each branch receives the input
        # concatenated with the outputs of all previous branches.
        aspp3 = self.aspp_3(x)
        x = torch.cat([aspp3, x], dim=1)

        aspp6 = self.aspp_6(x)
        x = torch.cat([aspp6, x], dim=1)

        aspp12 = self.aspp_12(x)
        x = torch.cat([aspp12, x], dim=1)

        aspp18 = self.aspp_18(x)
        x = torch.cat([aspp18, x], dim=1)

        aspp24 = self.aspp_24(x)
        x = torch.cat([aspp24, x], dim=1)

        return x


def get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False,
                  root='~/.torch/models', pretrained_base=True, **kwargs):
    r"""DenseASPP

    Parameters
    ----------
    dataset : str, default citys
        The dataset that the model was pretrained on. (pascal_voc, ade20k)
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for the model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    pretrained_base : bool or str, default True
        This will load a pretrained backbone network that was trained on ImageNet.

    Examples
    --------
    >>> model = get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False)
    >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DenseASPP(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_denseaspp_densenet121_citys(**kwargs):
    return get_denseaspp('citys', 'densenet121', **kwargs)


def get_denseaspp_densenet161_citys(**kwargs):
    return get_denseaspp('citys', 'densenet161', **kwargs)


def get_denseaspp_densenet169_citys(**kwargs):
    return get_denseaspp('citys', 'densenet169', **kwargs)


def get_denseaspp_densenet201_citys(**kwargs):
    return get_denseaspp('citys', 'densenet201', **kwargs)


if __name__ == '__main__':
    # img = torch.randn(2, 3, 480, 480)
    # model = get_denseaspp_densenet121_citys()
    # outputs = model(img)
    img = torch.rand(2, 3, 512, 512)
    model = DenseASPP(4, pretrained_base=True)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    # print(model)
    output = model(img)
    print(output, output.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # print a table with each layer's output shape and parameter count, in order

    # Count FLOPs and parameters with thop; profile expects the actual input tensors.
    from thop import profile

    flops, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))