"""Decoders Matter for Semantic Segmentation""" import torch import torch.nn as nn import torch.nn.functional as F from core.models.segbase import SegBaseModel from core.models.fcn import _FCNHead __all__ = ['DUNet', 'get_dunet', 'get_dunet_resnet50_pascal_voc', 'get_dunet_resnet101_pascal_voc', 'get_dunet_resnet152_pascal_voc'] # The model may be wrong because lots of details missing in paper. class DUNet(SegBaseModel): """Decoders Matter for Semantic Segmentation Reference: Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan. "Decoders Matter for Semantic Segmentation: Data-Dependent Decoding Enables Flexible Feature Aggregation." CVPR, 2019 """ def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs): super(DUNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) self.head = _DUHead(2144, **kwargs) self.dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) if aux: self.auxlayer = _FCNHead(1024, 256, **kwargs) self.aux_dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) self.__setattr__('exclusive', ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if aux else ['dupsample', 'head']) def forward(self, x): c1, c2, c3, c4 = self.base_forward(x)#继承自SegBaseModel;返回的是resnet的layer1,2,3,4的输出 outputs = [] x = self.head(c2, c3, c4) x = self.dupsample(x) outputs.append(x) if self.aux: auxout = self.auxlayer(c3) auxout = self.aux_dupsample(auxout) outputs.append(auxout) #return tuple(outputs) return outputs[0] class FeatureFused(nn.Module): """Module for fused features""" def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d, **kwargs): super(FeatureFused, self).__init__() self.conv2 = nn.Sequential( nn.Conv2d(512, inter_channels, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True) ) self.conv3 = nn.Sequential( nn.Conv2d(1024, inter_channels, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True) ) def forward(self, c2, c3, c4): size = c4.size()[2:] c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True)) c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True)) fused_feature = torch.cat([c4, c3, c2], dim=1) return fused_feature class _DUHead(nn.Module): def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs): super(_DUHead, self).__init__() self.fuse = FeatureFused(norm_layer=norm_layer, **kwargs) self.block = nn.Sequential( nn.Conv2d(in_channels, 256, 3, padding=1, bias=False), norm_layer(256), nn.ReLU(True), nn.Conv2d(256, 256, 3, padding=1, bias=False), norm_layer(256), nn.ReLU(True) ) def forward(self, c2, c3, c4): fused_feature = self.fuse(c2, c3, c4) out = self.block(fused_feature) return out class DUpsampling(nn.Module): """DUsampling module""" def __init__(self, in_channels, out_channels, scale_factor=2, **kwargs): super(DUpsampling, self).__init__() self.scale_factor = scale_factor self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False) def forward(self, x): x = self.conv_w(x) n, c, h, w = x.size() # N, C, H, W --> N, W, H, C x = x.permute(0, 3, 2, 1).contiguous() # N, W, H, C --> N, W, H * scale, C // scale x = x.view(n, w, h * self.scale_factor, c // self.scale_factor) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * self.scale_factor)) # N, H * scale, W * scale, C // (scale ** 2) -- > N, C // (scale ** 2), H * scale, W * scale x = x.permute(0, 3, 1, 2) return x def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs): acronyms = { 'pascal_voc': 'pascal_voc', 'pascal_aug': 'pascal_aug', 'ade20k': 'ade', 'coco': 'coco', 'citys': 'citys', } from ..data.dataloader import datasets model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) if pretrained: from .model_store import get_model_file device = torch.device(kwargs['local_rank']) model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root), map_location=device)) return model def get_dunet_resnet50_pascal_voc(**kwargs): return get_dunet('pascal_voc', 'resnet50', **kwargs) def get_dunet_resnet101_pascal_voc(**kwargs): return get_dunet('pascal_voc', 'resnet101', **kwargs) def get_dunet_resnet152_pascal_voc(**kwargs): return get_dunet('pascal_voc', 'resnet152', **kwargs) if __name__ == '__main__': # img = torch.randn(2, 3, 256, 256) # model = get_dunet_resnet50_pascal_voc() # outputs = model(img) input = torch.rand(2, 3, 224, 224) model = DUNet(4, pretrained_base=False) # target = torch.zeros(4, 512, 512).cuda() # model.eval() # print(model) loss = model(input) print(loss, loss.shape) # from torchsummary import summary # # summary(model, (3, 224, 224)) # 打印表格,按顺序输出每层的输出形状和参数 import torch from thop import profile from torchsummary import summary input = torch.randn(1, 3, 512, 512) flop, params = profile(model, inputs=(input, )) print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))