import torch import torch.nn as nn import torch.nn.functional as F from .base_models.xception import get_xception from .deeplabv3 import _ASPP from .fcn import _FCNHead from ..nn import _ConvBNReLU __all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc'] class DeepLabV3Plus(nn.Module): r"""DeepLabV3Plus Parameters ---------- nclass : int Number of categories for the training dataset. backbone : string Pre-trained dilated backbone network type (default:'xception'). norm_layer : object Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; for Synchronized Cross-GPU BachNormalization). aux : bool Auxiliary loss. Reference: Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation." """ def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs): super(DeepLabV3Plus, self).__init__() self.aux = aux self.nclass = nclass output_stride = 8 if dilated else 32 self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs) # deeplabv3 plus self.head = _DeepLabHead(nclass, **kwargs) if aux: self.auxlayer = _FCNHead(728, nclass, **kwargs) def base_forward(self, x): # Entry flow x = self.pretrained.conv1(x) x = self.pretrained.bn1(x) x = self.pretrained.relu(x) x = self.pretrained.conv2(x) x = self.pretrained.bn2(x) x = self.pretrained.relu(x) x = self.pretrained.block1(x) # add relu here x = self.pretrained.relu(x) low_level_feat = x x = self.pretrained.block2(x) x = self.pretrained.block3(x) # Middle flow x = self.pretrained.midflow(x) mid_level_feat = x # Exit flow x = self.pretrained.block20(x) x = self.pretrained.relu(x) x = self.pretrained.conv3(x) x = self.pretrained.bn3(x) x = self.pretrained.relu(x) x = self.pretrained.conv4(x) x = self.pretrained.bn4(x) x = self.pretrained.relu(x) x = self.pretrained.conv5(x) x = self.pretrained.bn5(x) x = self.pretrained.relu(x) return low_level_feat, mid_level_feat, x def forward(self, x): size = x.size()[2:] c1, c3, c4 = self.base_forward(x) outputs = list() x = self.head(c4, c1) x = F.interpolate(x, size, mode='bilinear', align_corners=True) outputs.append(x) if self.aux: auxout = self.auxlayer(c3) auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) outputs.append(auxout) return tuple(outputs) class _DeepLabHead(nn.Module): def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs): super(_DeepLabHead, self).__init__() self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs) self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer) self.block = nn.Sequential( _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer), nn.Dropout(0.5), _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer), nn.Dropout(0.1), nn.Conv2d(256, nclass, 1)) def forward(self, x, c1): size = c1.size()[2:] c1 = self.c1_block(c1) x = self.aspp(x) x = F.interpolate(x, size, mode='bilinear', align_corners=True) return self.block(torch.cat([x, c1], dim=1)) def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs): acronyms = { 'pascal_voc': 'pascal_voc', 'pascal_aug': 'pascal_aug', 'ade20k': 'ade', 'coco': 'coco', 'citys': 'citys', } from ..data.dataloader import datasets model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) if pretrained: from .model_store import get_model_file device = torch.device(kwargs['local_rank']) model.load_state_dict( torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root), map_location=device)) return model def get_deeplabv3_plus_xception_voc(**kwargs): return get_deeplabv3_plus('pascal_voc', 'xception', **kwargs) if __name__ == '__main__': model = get_deeplabv3_plus_xception_voc()