143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .base_models.xception import get_xception
|
|
from .deeplabv3 import _ASPP
|
|
from .fcn import _FCNHead
|
|
from ..nn import _ConvBNReLU
|
|
|
|
# Public API of this module: the model class, the generic factory function,
# and the Pascal VOC / Xception convenience constructor.
__all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc']
|
|
|
|
|
|
class DeepLabV3Plus(nn.Module):
    r"""DeepLabV3+ semantic segmentation model built on an Xception encoder.

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : str
        Backbone network type (default: 'xception'). Only 'xception' is
        wired up in this class; any other value raises ``ValueError``.
    aux : bool
        If True, add an auxiliary FCN head on the middle-flow features and
        return its prediction as a second output (default: True).
    pretrained_base : bool
        Forwarded to ``get_xception(pretrained=...)`` (default: True).
    dilated : bool
        Selects the backbone ``output_stride``: 8 when True, 32 when False
        (default: True).
    norm_layer : object, optional keyword
        If given in ``**kwargs`` it is forwarded to the backbone and heads.
        ``_DeepLabHead`` falls back to :class:`nn.BatchNorm2d`; a
        synchronized cross-GPU batch-normalization class can be passed instead.
    **kwargs
        Forwarded unchanged to ``get_xception``, ``_DeepLabHead`` and
        ``_FCNHead``.

    Raises
    ------
    ValueError
        If ``backbone`` is not 'xception'.

    Reference:

        Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
        Image Segmentation."
    """

    def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs):
        # Only the Xception encoder is built below (get_xception) and walked in
        # base_forward; reject other names instead of silently ignoring them.
        if backbone != 'xception':
            raise ValueError("DeepLabV3Plus only supports backbone='xception', got %r" % (backbone,))
        super().__init__()
        self.aux = aux
        self.nclass = nclass
        output_stride = 8 if dilated else 32

        self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs)

        # DeepLabV3+ decoder (ASPP + low-level feature fusion).
        self.head = _DeepLabHead(nclass, **kwargs)
        if aux:
            # 728 must match the channel count of the backbone's middle-flow output.
            self.auxlayer = _FCNHead(728, nclass, **kwargs)

    def base_forward(self, x):
        """Run the Xception encoder and collect the features used by the heads.

        Returns
        -------
        tuple
            ``(low_level_feat, mid_level_feat, x)``: the output of ``block1``
            (after ReLU), the output of ``midflow``, and the final exit-flow
            output (after ``conv5``/``bn5``/ReLU). Channel counts are
            presumably 128 / 728 / 2048, matching the head constructors —
            confirm against ``get_xception``.
        """
        # Entry flow
        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)

        x = self.pretrained.conv2(x)
        x = self.pretrained.bn2(x)
        x = self.pretrained.relu(x)

        x = self.pretrained.block1(x)
        # Extra ReLU on block1's output before it is used as the low-level feature.
        x = self.pretrained.relu(x)
        low_level_feat = x

        x = self.pretrained.block2(x)
        x = self.pretrained.block3(x)

        # Middle flow
        x = self.pretrained.midflow(x)
        mid_level_feat = x

        # Exit flow
        x = self.pretrained.block20(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.conv3(x)
        x = self.pretrained.bn3(x)
        x = self.pretrained.relu(x)

        x = self.pretrained.conv4(x)
        x = self.pretrained.bn4(x)
        x = self.pretrained.relu(x)

        x = self.pretrained.conv5(x)
        x = self.pretrained.bn5(x)
        x = self.pretrained.relu(x)
        return low_level_feat, mid_level_feat, x

    def forward(self, x):
        """Predict per-pixel class scores.

        Returns a tuple whose first element is the main head's output, and,
        when ``aux`` is enabled, whose second element is the auxiliary head's
        output. Both are bilinearly resized (``align_corners=True``) to the
        spatial size of ``x``.
        """
        size = x.size()[2:]
        c1, c3, c4 = self.base_forward(x)
        outputs = list()
        x = self.head(c4, c1)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        return tuple(outputs)
|
|
|
|
|
|
class _DeepLabHead(nn.Module):
|
|
def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs):
|
|
super(_DeepLabHead, self).__init__()
|
|
self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)
|
|
self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
|
|
self.block = nn.Sequential(
|
|
_ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
|
|
nn.Dropout(0.5),
|
|
_ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
|
|
nn.Dropout(0.1),
|
|
nn.Conv2d(256, nclass, 1))
|
|
|
|
def forward(self, x, c1):
|
|
size = c1.size()[2:]
|
|
c1 = self.c1_block(c1)
|
|
x = self.aspp(x)
|
|
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
|
|
return self.block(torch.cat([x, c1], dim=1))
|
|
|
|
|
|
def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models',
                       pretrained_base=True, **kwargs):
    """Build a DeepLabV3Plus model for ``dataset``, optionally loading pretrained weights.

    Parameters
    ----------
    dataset : str
        Key into ``datasets`` (from ``..data.dataloader``); its ``NUM_CLASS``
        sets the number of output classes. When ``pretrained`` is True it must
        also be one of 'pascal_voc', 'pascal_aug', 'ade20k', 'coco', 'citys'.
    backbone : str
        Backbone name; forwarded to the model and used in the weight-file name.
    pretrained : bool
        If True, load the state dict named
        ``deeplabv3_plus_<backbone>_<dataset acronym>`` via ``get_model_file``.
    root : str
        Directory passed to ``get_model_file``.
    pretrained_base : bool
        Forwarded to ``DeepLabV3Plus``.
    **kwargs
        Forwarded to ``DeepLabV3Plus``. If ``local_rank`` is present it selects
        the device the checkpoint is mapped to; otherwise the checkpoint is
        mapped to CPU.

    Returns
    -------
    DeepLabV3Plus
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # local_rank is optional (e.g. single-process use); indexing kwargs
        # directly raised KeyError when it was absent. Fall back to CPU.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu') if local_rank is None else torch.device(local_rank)
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        model.load_state_dict(
            torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root),
                       map_location=device))
    return model
|
|
|
|
|
|
def get_deeplabv3_plus_xception_voc(**kwargs):
    """Build a DeepLabV3Plus with an Xception backbone for Pascal VOC.

    Every keyword argument is passed through to ``get_deeplabv3_plus``.
    """
    dataset, backbone = 'pascal_voc', 'xception'
    return get_deeplabv3_plus(dataset, backbone, **kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check: construct the VOC / Xception model with default arguments.
    net = get_deeplabv3_plus_xception_voc()
|