|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from core.models.base_models.densenet import *
- from core.models.fcn import _FCNHead
-
- __all__ = ['DenseASPP', 'get_denseaspp', 'get_denseaspp_densenet121_citys',
- 'get_denseaspp_densenet161_citys', 'get_denseaspp_densenet169_citys', 'get_denseaspp_densenet201_citys']
-
-
class DenseASPP(nn.Module):
    """DenseASPP segmentation network built on a dilated DenseNet backbone.

    Parameters
    ----------
    nclass : int
        Number of output channels produced by the segmentation head.
    backbone : str, default 'densenet121'
        One of 'densenet121', 'densenet161', 'densenet169', 'densenet201'.
    aux : bool, default False
        When True an auxiliary ``_FCNHead`` is created and evaluated in
        ``forward``; its output is not part of the return value.
    jpu : bool, default False
        Accepted for signature compatibility; not used by this class.
    pretrained_base : bool, default True
        Forwarded to the backbone constructor as ``pretrained``.
    dilate_scale : int, default 8
        Forwarded to the backbone constructor. Values above 8 make
        ``forward`` upsample the backbone features by a factor of 2.
    **kwargs
        Forwarded to the backbone constructor and, when ``aux`` is set,
        to ``_FCNHead``.

    Raises
    ------
    RuntimeError
        If ``backbone`` is not one of the supported names.
    """

    def __init__(self, nclass, backbone='densenet121', aux=False, jpu=False,
                 pretrained_base=True, dilate_scale=8, **kwargs):
        super(DenseASPP, self).__init__()
        self.nclass = nclass
        self.aux = aux
        self.dilate_scale = dilate_scale

        # Select the backbone constructor first so it is invoked in one place.
        if backbone == 'densenet121':
            build_backbone = dilated_densenet121
        elif backbone == 'densenet161':
            build_backbone = dilated_densenet161
        elif backbone == 'densenet169':
            build_backbone = dilated_densenet169
        elif backbone == 'densenet201':
            build_backbone = dilated_densenet201
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = build_backbone(dilate_scale, pretrained=pretrained_base, **kwargs)
        feat_channels = self.pretrained.num_features

        self.head = _DenseASPPHead(feat_channels, nclass)
        if aux:
            self.auxlayer = _FCNHead(feat_channels, nclass, **kwargs)

        # Names of the sub-modules that sit on top of the backbone.
        self.exclusive = ['head', 'auxlayer'] if aux else ['head']

    def forward(self, x):
        """Return the head output resized to the spatial size of ``x``."""
        out_size = x.size()[2:]
        feats = self.pretrained.features(x)
        # Original author's note: (2, 1024, 64, 64) was observed here for a
        # 512x512 input.
        if self.dilate_scale > 8:
            feats = F.interpolate(feats, scale_factor=2, mode='bilinear', align_corners=True)

        logits = self.head(feats)
        # Original author's note: upsampling straight from 64 to 512 in a
        # single step still gives good results.
        logits = F.interpolate(logits, out_size, mode='bilinear', align_corners=True)

        if self.aux:
            # The auxiliary branch is still evaluated, as before, but only the
            # main output is returned, so its result is discarded.
            aux_logits = self.auxlayer(feats)
            F.interpolate(aux_logits, out_size, mode='bilinear', align_corners=True)
        return logits
-
class _DenseASPPHead(nn.Module):
    """Segmentation head: a ``_DenseASPPBlock`` followed by dropout and a 1x1 classifier.

    Parameters
    ----------
    in_channels : int
        Channel count of the incoming feature map.
    nclass : int
        Number of output channels of the final 1x1 convolution.
    norm_layer, norm_kwargs
        Forwarded to ``_DenseASPPBlock``.
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DenseASPPHead, self).__init__()
        self.dense_aspp_block = _DenseASPPBlock(in_channels, 256, 64, norm_layer, norm_kwargs)
        # The block concatenates its input with five 64-channel branch outputs.
        classifier_in = in_channels + 5 * 64
        self.block = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(classifier_in, nclass, 1),
        )

    def forward(self, x):
        """Run the dense ASPP block, then the dropout + 1x1 classifier."""
        return self.block(self.dense_aspp_block(x))
-
-
- class _DenseASPPConv(nn.Sequential):
- def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
- drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
- super(_DenseASPPConv, self).__init__()
- self.add_module('conv1', nn.Conv2d(in_channels, inter_channels, 1)),
- self.add_module('bn1', norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs))),
- self.add_module('relu1', nn.ReLU(True)),
- self.add_module('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)),
- self.add_module('bn2', norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs))),
- self.add_module('relu2', nn.ReLU(True)),
- self.drop_rate = drop_rate
-
- def forward(self, x):
- features = super(_DenseASPPConv, self).forward(x)
- if self.drop_rate > 0:
- features = F.dropout(features, p=self.drop_rate, training=self.training)
- return features
-
-
class _DenseASPPBlock(nn.Module):
    """Five densely connected ``_DenseASPPConv`` branches (rates 3, 6, 12, 18, 24).

    Each branch receives the original input concatenated with the outputs
    of all previous branches, so the k-th branch (0-based) takes
    ``in_channels + k * inter_channels2`` channels. The block returns the
    input concatenated with all five branch outputs.

    Parameters
    ----------
    in_channels : int
        Channels of the input feature map.
    inter_channels1 : int
        Bottleneck width inside every branch.
    inter_channels2 : int
        Output channels of every branch.
    norm_layer, norm_kwargs
        Forwarded to each ``_DenseASPPConv``.
    """

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        # Registered as aspp_3, aspp_6, aspp_12, aspp_18, aspp_24, in that order.
        for idx, rate in enumerate((3, 6, 12, 18, 24)):
            branch = _DenseASPPConv(in_channels + inter_channels2 * idx,
                                    inter_channels1, inter_channels2, rate, 0.1,
                                    norm_layer, norm_kwargs)
            setattr(self, 'aspp_{}'.format(rate), branch)

    def forward(self, x):
        """Run the branches in order, prepending each output to the running tensor."""
        branches = (self.aspp_3, self.aspp_6, self.aspp_12, self.aspp_18, self.aspp_24)
        for branch in branches:
            # The newest branch output goes in front along the channel axis.
            x = torch.cat([branch(x), x], dim=1)
        return x
-
-
def get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False,
                  root='~/.torch/models', pretrained_base=True, **kwargs):
    r"""Build a DenseASPP model, optionally loading pretrained weights.

    Parameters
    ----------
    dataset : str, default citys
        Key into the project's ``datasets`` registry; its ``NUM_CLASS``
        sets the number of output classes. Pretrained weight files are
        known for: pascal_voc, pascal_aug, ade20k, coco, citys.
    backbone : str, default densenet121
        Backbone name passed to ``DenseASPP``.
    pretrained : bool or str
        When truthy, load the weights file named
        ``denseaspp_<backbone>_<dataset acronym>`` through ``get_model_file``.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    pretrained_base : bool or str, default True
        Passed to ``DenseASPP`` to control loading of the backbone weights.
    **kwargs
        Forwarded to ``DenseASPP``. If ``local_rank`` is present it selects
        the device the pretrained weights are mapped to; otherwise they
        are mapped to the CPU.

    Returns
    -------
    DenseASPP
        The constructed model.

    Examples
    --------
    # >>> model = get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False)
    # >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DenseASPP(datasets[dataset].NUM_CLASS, backbone=backbone,
                      pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # ``local_rank`` is optional: previously a missing key raised KeyError
        # here; now the weights are mapped to the CPU in that case.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu' if local_rank is None else local_rank)
        weights_path = get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root)
        model.load_state_dict(torch.load(weights_path, map_location=device))
    return model
-
-
def get_denseaspp_densenet121_citys(**kwargs):
    """Return ``get_denseaspp`` for dataset 'citys' with a densenet121 backbone."""
    return get_denseaspp(dataset='citys', backbone='densenet121', **kwargs)
-
-
def get_denseaspp_densenet161_citys(**kwargs):
    """Return ``get_denseaspp`` for dataset 'citys' with a densenet161 backbone."""
    return get_denseaspp(dataset='citys', backbone='densenet161', **kwargs)
-
-
def get_denseaspp_densenet169_citys(**kwargs):
    """Return ``get_denseaspp`` for dataset 'citys' with a densenet169 backbone."""
    return get_denseaspp(dataset='citys', backbone='densenet169', **kwargs)
-
-
def get_denseaspp_densenet201_citys(**kwargs):
    """Return ``get_denseaspp`` for dataset 'citys' with a densenet201 backbone."""
    return get_denseaspp(dataset='citys', backbone='densenet201', **kwargs)
-
-
if __name__ == '__main__':
    # Manual smoke test: one forward pass on random data, then a
    # FLOPs / parameter count via thop.
    images = torch.rand(2, 3, 512, 512)  # renamed: ``input`` shadowed the builtin
    model = DenseASPP(4, pretrained_base=True)
    logits = model(images)
    print(logits, logits.shape)

    # thop is an optional third-party tool, imported only when this file is
    # run as a script.
    from thop import profile

    profile_shape = (1, 3, 512, 512)
    try:
        # Keyword documented by current thop releases.
        flop, params = profile(model, inputs=(torch.rand(*profile_shape),))
    except TypeError:
        # The original call used the legacy ``input_size`` keyword; keep it
        # as a fallback for installations that only accept that form.
        flop, params = profile(model, input_size=profile_shape)
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))
|