- """Image Cascade Network"""
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from core.models.segbase import SegBaseModel
-
- __all__ = ['ICNet', 'get_icnet', 'get_icnet_resnet50_citys',
- 'get_icnet_resnet101_citys', 'get_icnet_resnet152_citys']
-
-
class ICNet(SegBaseModel):
    """Image Cascade Network (ICNet) for real-time semantic segmentation.

    Three resolution branches are combined: a shallow full-resolution
    branch (``conv_sub1``), plus 1/2- and 1/4-scale branches that reuse
    the shared ResNet backbone, fused by ``_ICHead``.
    """

    def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained_base=True, **kwargs):
        super(ICNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
        # Shallow high-resolution branch: three stride-2 convs (1/8 of input size).
        self.conv_sub1 = nn.Sequential(
            _ConvBNReLU(3, 32, 3, 2, **kwargs),
            _ConvBNReLU(32, 32, 3, 2, **kwargs),
            _ConvBNReLU(32, 64, 3, 2, **kwargs),
        )
        # Parameter-free pyramid pooling applied to the deepest features.
        self.ppm = PyramidPoolingModule()
        # Cascade-fusion head that produces the multi-scale logits.
        self.head = _ICHead(nclass, **kwargs)
        # Modules trained from scratch (kept separate from the pretrained backbone).
        self.__setattr__('exclusive', ['conv_sub1', 'head'])

    def forward(self, x):
        # Branch 1: shallow convs over the full-resolution image.
        branch1 = self.conv_sub1(x)

        # Branch 2: backbone stage-2 features of the half-scale image.
        half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
        _, branch2, _, _ = self.base_forward(half)

        # Branch 4: backbone stage-4 features of the quarter-scale image,
        # enriched with pyramid-pooled context before fusion.
        quarter = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
        _, _, _, branch4 = self.base_forward(quarter)
        branch4 = self.ppm(branch4)

        return tuple(self.head(branch1, branch2, branch4))
-
class PyramidPoolingModule(nn.Module):
    """Parameter-free pyramid pooling.

    For each bin size, adaptively average-pools the input, bilinearly
    upsamples the result back to the input's spatial size, and adds it
    onto the input, accumulating multi-scale context.

    Parameters
    ----------
    pyramids : iterable of int
        Adaptive-pool output sizes (bins). Default ``(1, 2, 3, 6)``.
        NOTE: the original used a mutable list default; a tuple avoids
        the shared-mutable-default pitfall with identical behavior.
    """

    def __init__(self, pyramids=(1, 2, 3, 6)):
        super(PyramidPoolingModule, self).__init__()
        self.pyramids = pyramids

    def forward(self, input):
        height, width = input.shape[2:]
        feat = input
        for bin_size in self.pyramids:
            pooled = F.adaptive_avg_pool2d(input, output_size=bin_size)
            # Upsample the pooled context back to the input resolution.
            upsampled = F.interpolate(pooled, size=(height, width),
                                      mode='bilinear', align_corners=True)
            feat = feat + upsampled
        return feat
-
class _ICHead(nn.Module):
    """ICNet fusion head.

    Cascades the 1/4-scale backbone features into the 1/2-scale features,
    then into the shallow full-resolution branch, emitting logits at four
    scales (coarsest-supervision outputs last).
    """

    def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_ICHead, self).__init__()
        self.cff_12 = CascadeFeatureFusion(128, 64, 128, nclass, norm_layer, **kwargs)
        self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs)
        self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False)

    def forward(self, x_sub1, x_sub2, x_sub4):
        # Fuse the deepest features into the 1/2 branch, then that result
        # into the shallow branch; each fusion yields auxiliary logits.
        fused_24, cls_24 = self.cff_24(x_sub4, x_sub2)
        fused_12, cls_12 = self.cff_12(fused_24, x_sub1)

        # Upsample, classify, then upsample the logits once more.
        up2 = F.interpolate(fused_12, scale_factor=2, mode='bilinear', align_corners=True)
        up2 = self.conv_cls(up2)
        up8 = F.interpolate(up2, scale_factor=4, mode='bilinear', align_corners=True)

        # Finest prediction first, auxiliary (deep-supervision) logits after.
        return [up8, up2, cls_12, cls_24]
-
-
- class _ConvBNReLU(nn.Module):
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1,
- groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs):
- super(_ConvBNReLU, self).__init__()
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
- self.bn = norm_layer(out_channels)
- self.relu = nn.ReLU(True)
-
- def forward(self, x):
- x = self.conv(x)
- x = self.bn(x)
- x = self.relu(x)
- return x
-
-
class CascadeFeatureFusion(nn.Module):
    """Cascade Feature Fusion (CFF) unit.

    Upsamples the low-resolution feature map to the high-resolution map's
    spatial size, projects both to ``out_channels``, sums and ReLUs them,
    and additionally returns auxiliary class logits computed from the
    upsampled low branch for deep supervision.
    """

    def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
        super(CascadeFeatureFusion, self).__init__()
        # Dilated 3x3 keeps a large receptive field on the upsampled branch.
        self.conv_low = nn.Sequential(
            nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False),
            norm_layer(out_channels),
        )
        # 1x1 projection for the high-resolution branch.
        self.conv_high = nn.Sequential(
            nn.Conv2d(high_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
        )
        # Auxiliary classifier on the pre-fusion low branch.
        self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False)

    def forward(self, x_low, x_high):
        upsampled = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True)
        low_feat = self.conv_low(upsampled)
        high_feat = self.conv_high(x_high)
        fused = F.relu(low_feat + high_feat, inplace=True)
        aux_logits = self.conv_low_cls(low_feat)
        return fused, aux_logits
-
-
def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    """Construct an ICNet model.

    Parameters
    ----------
    dataset : str
        Dataset key used to look up the class count and the pretrained
        weight file name (e.g. 'citys').
    backbone : str
        ResNet backbone variant: 'resnet50', 'resnet101' or 'resnet152'.
    pretrained : bool
        If True, load pretrained weights for the full model.
    root : str
        Directory holding pretrained weight files.
    pretrained_base : bool
        If True, initialize the backbone with ImageNet-pretrained weights.

    Returns
    -------
    ICNet
        The constructed (and optionally pretrained) model.
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Map weights to the caller's device; fall back to CPU when no
        # 'local_rank' was supplied (the original indexed kwargs directly,
        # raising KeyError in single-process runs).
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
-
-
def get_icnet_resnet50_citys(**kwargs):
    """ICNet with a ResNet-50 backbone for Cityscapes."""
    return get_icnet(dataset='citys', backbone='resnet50', **kwargs)
-
-
def get_icnet_resnet101_citys(**kwargs):
    """ICNet with a ResNet-101 backbone for Cityscapes."""
    return get_icnet(dataset='citys', backbone='resnet101', **kwargs)
-
-
def get_icnet_resnet152_citys(**kwargs):
    """ICNet with a ResNet-152 backbone for Cityscapes."""
    return get_icnet(dataset='citys', backbone='resnet152', **kwargs)
-
-
if __name__ == '__main__':
    # Smoke test: forward a random batch through an un-pretrained ICNet.
    # (Renamed from `input`, which shadowed the builtin.)
    images = torch.rand(2, 3, 224, 224)
    model = ICNet(4, pretrained_base=False)
    outputs = model(images)
    print([tuple(out.shape) for out in outputs])

    # Profile FLOPs/params with thop (third-party). Current thop takes the
    # sample inputs themselves via `inputs=`; the old `input_size` keyword
    # used here previously has been removed from the library.
    from thop import profile

    flops, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
|