AIlib2/segutils/core/models/denseaspp.py

198 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from core.models.base_models.densenet import *
from core.models.fcn import _FCNHead
# Public names exported by `from ... import *`.
__all__ = [
    'DenseASPP',
    'get_denseaspp',
    'get_denseaspp_densenet121_citys',
    'get_denseaspp_densenet161_citys',
    'get_denseaspp_densenet169_citys',
    'get_denseaspp_densenet201_citys',
]
class DenseASPP(nn.Module):
    r"""DenseASPP segmentation network on top of a dilated DenseNet backbone.

    Parameters
    ----------
    nclass : int
        Number of output classes (channels of the returned logits).
    backbone : str, default 'densenet121'
        One of 'densenet121', 'densenet161', 'densenet169', 'densenet201'.
        Any other value raises RuntimeError.
    aux : bool, default False
        Also build an auxiliary ``_FCNHead`` on the backbone features.
    jpu : bool, default False
        Accepted for interface compatibility; not used by this model.
    pretrained_base : bool, default True
        Forwarded to the backbone constructor as ``pretrained``.
    dilate_scale : int, default 8
        Forwarded to the backbone constructor. When greater than 8, the
        backbone features are bilinearly upsampled by 2 before the head.
    return_aux : bool, default False
        When True and ``aux`` is True, ``forward`` returns ``(main, aux)``.
        Otherwise only the main logits are returned and the auxiliary head
        is not evaluated at all.
    **kwargs
        Forwarded to the backbone constructor and to ``_FCNHead``.
    """

    def __init__(self, nclass, backbone='densenet121', aux=False, jpu=False,
                 pretrained_base=True, dilate_scale=8, return_aux=False, **kwargs):
        super(DenseASPP, self).__init__()
        self.nclass = nclass
        self.aux = aux
        self.return_aux = return_aux
        self.dilate_scale = dilate_scale
        if backbone == 'densenet121':
            self.pretrained = dilated_densenet121(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet161':
            self.pretrained = dilated_densenet161(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet169':
            self.pretrained = dilated_densenet169(dilate_scale, pretrained=pretrained_base, **kwargs)
        elif backbone == 'densenet201':
            self.pretrained = dilated_densenet201(dilate_scale, pretrained=pretrained_base, **kwargs)
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        in_channels = self.pretrained.num_features
        self.head = _DenseASPPHead(in_channels, nclass)
        if aux:
            self.auxlayer = _FCNHead(in_channels, nclass, **kwargs)
        # Names of the non-backbone sub-modules. Presumably read by the training
        # code (e.g. for separate learning rates); TODO confirm against caller.
        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        """Return per-pixel class logits at the spatial size of ``x``.

        Returns a tensor, or a ``(main, aux)`` tuple when both ``aux`` and
        ``return_aux`` were enabled at construction.
        """
        size = x.size()[2:]
        # Per the original author's debug notes: a 2x3x512x512 input gave
        # 2x1024x64x64 features with densenet121.
        features = self.pretrained.features(x)
        if self.dilate_scale > 8:
            features = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=True)
        out = self.head(features)
        # Upsample the logits straight to the input resolution in a single
        # bilinear step (the author notes this works well even for 64 -> 512).
        out = F.interpolate(out, size, mode='bilinear', align_corners=True)
        # getattr keeps whole-module pickles saved before `return_aux` existed loadable.
        if not (self.aux and getattr(self, 'return_aux', False)):
            # The aux head is only evaluated when its output is actually returned;
            # previously it was computed and then discarded.
            return out
        auxout = self.auxlayer(features)
        auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
        return out, auxout
class _DenseASPPHead(nn.Module):
    """Segmentation head: a dense ASPP block, then dropout and a 1x1 classifier.

    ``norm_layer`` and ``norm_kwargs`` are handed to ``_DenseASPPBlock``; any
    extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DenseASPPHead, self).__init__()
        reduce_channels = 256   # 1x1 bottleneck width inside every ASPP branch
        branch_channels = 64    # channels each ASPP branch contributes
        num_branches = 5        # _DenseASPPBlock builds five dilated branches
        self.dense_aspp_block = _DenseASPPBlock(in_channels, reduce_channels, branch_channels,
                                                norm_layer, norm_kwargs)
        # The dense block concatenates every branch output onto its input, so the
        # classifier sees the original channels plus all branch channels.
        classifier_in = in_channels + num_branches * branch_channels
        self.block = nn.Sequential(nn.Dropout(0.1), nn.Conv2d(classifier_in, nclass, 1))

    def forward(self, x):
        return self.block(self.dense_aspp_block(x))
class _DenseASPPConv(nn.Sequential):
    """One DenseASPP branch.

    Layers: 1x1 conv -> norm -> ReLU -> dilated 3x3 conv -> norm -> ReLU,
    followed by functional dropout with probability ``drop_rate`` (skipped when
    ``drop_rate <= 0``). The 3x3 conv pads by ``atrous_rate``, so spatial size
    is preserved.
    """

    def __init__(self, in_channels, inter_channels, out_channels, atrous_rate,
                 drop_rate=0.1, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPConv, self).__init__()
        nkw = norm_kwargs if norm_kwargs is not None else {}
        # Built in this exact order so parameter init and state_dict keys match.
        layers = (
            ('conv1', nn.Conv2d(in_channels, inter_channels, 1)),
            ('bn1', norm_layer(inter_channels, **nkw)),
            ('relu1', nn.ReLU(True)),
            ('conv2', nn.Conv2d(inter_channels, out_channels, 3, dilation=atrous_rate, padding=atrous_rate)),
            ('bn2', norm_layer(out_channels, **nkw)),
            ('relu2', nn.ReLU(True)),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)
        self.drop_rate = drop_rate

    def forward(self, x):
        out = super(_DenseASPPConv, self).forward(x)
        if self.drop_rate <= 0:
            return out
        return F.dropout(out, p=self.drop_rate, training=self.training)
class _DenseASPPBlock(nn.Module):
    """Five densely connected dilated branches (rates 3, 6, 12, 18, 24).

    Each branch receives the block input concatenated with the outputs of all
    earlier branches. Its own output is prepended to that running tensor.
    Output channels: ``in_channels + 5 * inter_channels2``.
    """

    def __init__(self, in_channels, inter_channels1, inter_channels2,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(_DenseASPPBlock, self).__init__()
        # Registered as aspp_3, aspp_6, ... in this order (state_dict keys depend on it).
        for depth, rate in enumerate((3, 6, 12, 18, 24)):
            branch = _DenseASPPConv(in_channels + depth * inter_channels2, inter_channels1,
                                    inter_channels2, rate, 0.1, norm_layer, norm_kwargs)
            self.add_module('aspp_%d' % rate, branch)

    def forward(self, x):
        for branch in (self.aspp_3, self.aspp_6, self.aspp_12, self.aspp_18, self.aspp_24):
            x = torch.cat([branch(x), x], dim=1)
        return x
def get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False,
                  root='~/.torch/models', pretrained_base=True, **kwargs):
    r"""Build a DenseASPP model for ``dataset``, optionally loading stored weights.

    Parameters
    ----------
    dataset : str, default 'citys'
        Key into the project's dataset registry; its ``NUM_CLASS`` sets the
        number of output classes. Stored weights are looked up only for
        'pascal_voc', 'pascal_aug', 'ade20k', 'coco' and 'citys'.
    backbone : str, default 'densenet121'
        Backbone name passed to ``DenseASPP``.
    pretrained : bool
        When truthy, load the state dict located by ``get_model_file`` under
        ``root``. Only truthiness is used; no version tag is selected here.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    pretrained_base : bool, default True
        Passed to ``DenseASPP`` to load the ImageNet-pretrained backbone.
    **kwargs
        Forwarded to ``DenseASPP``. If ``local_rank`` is given it selects the
        ``map_location`` device for the weights; otherwise they load on CPU.

    Examples
    --------
    Build an untrained Cityscapes model::

        model = get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False)
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DenseASPP(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Previously kwargs['local_rank'] raised KeyError when the caller omitted it.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu') if local_rank is None else torch.device(local_rank)
        model_file = get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root)
        # NOTE(review): torch.load unpickles the file; only load trusted weight files.
        model.load_state_dict(torch.load(model_file, map_location=device))
    return model
def get_denseaspp_densenet121_citys(**kwargs):
    """Build DenseASPP with a DenseNet-121 backbone for 'citys'.

    All keyword arguments are forwarded to ``get_denseaspp``.
    """
    return get_denseaspp(dataset='citys', backbone='densenet121', **kwargs)
def get_denseaspp_densenet161_citys(**kwargs):
    """Build DenseASPP with a DenseNet-161 backbone for 'citys'.

    All keyword arguments are forwarded to ``get_denseaspp``.
    """
    return get_denseaspp(dataset='citys', backbone='densenet161', **kwargs)
def get_denseaspp_densenet169_citys(**kwargs):
    """Build DenseASPP with a DenseNet-169 backbone for 'citys'.

    All keyword arguments are forwarded to ``get_denseaspp``.
    """
    return get_denseaspp(dataset='citys', backbone='densenet169', **kwargs)
def get_denseaspp_densenet201_citys(**kwargs):
    """Build DenseASPP with a DenseNet-201 backbone for 'citys'.

    All keyword arguments are forwarded to ``get_denseaspp``.
    """
    return get_denseaspp(dataset='citys', backbone='densenet201', **kwargs)
if __name__ == '__main__':
    # Smoke test: push a random batch through a 4-class model, then report
    # its compute cost and parameter count with thop.
    from thop import profile

    img = torch.rand(2, 3, 512, 512)
    # pretrained_base=True asks the backbone constructor for pretrained weights.
    model = DenseASPP(4, pretrained_base=True)
    with torch.no_grad():
        out = model(img)
    print(out, out.shape)

    # Current thop takes a tuple of example inputs via `inputs=`; the legacy
    # `input_size=` keyword is not accepted. It returns (multiply-accumulate
    # count, parameter count).
    macs, params = profile(model, inputs=(torch.rand(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(macs / 1e9, params / 1e6))