AIlib2/segutils/core/models/deeplabv3_plus.py

143 lines
4.6 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_models.xception import get_xception
from .deeplabv3 import _ASPP
from .fcn import _FCNHead
from ..nn import _ConvBNReLU
# Names exported by a star-import of this module: the model class and its two factory functions.
__all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc']
class DeepLabV3Plus(nn.Module):
    r"""DeepLabV3+ segmentation network on top of an Xception backbone.

    Parameters
    ----------
    nclass : int
        Number of categories predicted per pixel.
    backbone : string
        Accepted for interface compatibility (default: 'xception'). The value is
        not consulted here: the backbone is always built with ``get_xception``.
    aux : bool
        When true, an auxiliary ``_FCNHead`` is attached to the middle-flow
        features and its prediction is returned as a second output.
    pretrained_base : bool
        Forwarded to ``get_xception`` as its ``pretrained`` argument.
    dilated : bool
        Selects the backbone output stride: 8 when true, 32 otherwise.
    **kwargs
        Forwarded unchanged to ``get_xception``, ``_DeepLabHead`` and
        ``_FCNHead`` (for example ``norm_layer``).

    Reference:
        Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
        Image Segmentation."
    """

    def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs):
        super().__init__()
        self.aux = aux
        self.nclass = nclass

        if dilated:
            output_stride = 8
        else:
            output_stride = 32

        # Submodules are registered in the order pretrained -> head -> auxlayer.
        self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs)
        self.head = _DeepLabHead(nclass, **kwargs)
        if aux:
            # 728 is the channel count fed to the auxiliary head from the middle flow.
            self.auxlayer = _FCNHead(728, nclass, **kwargs)

    def base_forward(self, x):
        """Run the backbone and return ``(low_level, mid_level, final)`` feature maps."""
        net = self.pretrained

        # Entry flow
        x = net.relu(net.bn1(net.conv1(x)))
        x = net.relu(net.bn2(net.conv2(x)))
        # An extra ReLU is applied to block1's output before it is tapped.
        low_level_feat = net.relu(net.block1(x))
        x = net.block3(net.block2(low_level_feat))

        # Middle flow
        mid_level_feat = net.midflow(x)

        # Exit flow
        x = net.relu(net.block20(mid_level_feat))
        x = net.relu(net.bn3(net.conv3(x)))
        x = net.relu(net.bn4(net.conv4(x)))
        x = net.relu(net.bn5(net.conv5(x)))
        return low_level_feat, mid_level_feat, x

    def forward(self, x):
        """Return a tuple of predictions resized to the spatial size of ``x``.

        The tuple holds the main prediction and, when ``self.aux`` is truthy,
        the auxiliary prediction as a second element.
        """
        out_size = x.size()[2:]
        low_level, mid_level, deep = self.base_forward(x)

        main_out = F.interpolate(self.head(deep, low_level), out_size, mode='bilinear', align_corners=True)
        if not self.aux:
            return (main_out,)

        aux_out = F.interpolate(self.auxlayer(mid_level), out_size, mode='bilinear', align_corners=True)
        return main_out, aux_out
class _DeepLabHead(nn.Module):
    """Decoder head fusing ASPP output with projected low-level features.

    Parameters
    ----------
    nclass : int
        Number of output channels of the final 1x1 convolution.
    c1_channels : int
        Channel count of the low-level feature map passed as ``c1``.
    norm_layer : object
        Normalization layer handed to ``_ASPP`` and every ``_ConvBNReLU``.
    **kwargs
        Forwarded to ``_ASPP`` only.
    """

    def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)
        # Project the low-level features down to 48 channels before fusion.
        self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
        # 304 input channels: the 48 projected channels plus, presumably, 256 from
        # _ASPP (its output width is not visible in this file).
        fusion_layers = [
            _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
            nn.Dropout(0.5),
            _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
            nn.Dropout(0.1),
            nn.Conv2d(256, nclass, 1),
        ]
        self.block = nn.Sequential(*fusion_layers)

    def forward(self, x, c1):
        """Upsample ASPP(x) to ``c1``'s spatial size, concatenate with projected ``c1``, classify."""
        target_size = c1.size()[2:]
        low_level = self.c1_block(c1)
        context = F.interpolate(self.aspp(x), target_size, mode='bilinear', align_corners=True)
        # Channel order in the fused tensor is [context, low_level].
        fused = torch.cat([context, low_level], dim=1)
        return self.block(fused)
def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models',
                       pretrained_base=True, **kwargs):
    """Build a ``DeepLabV3Plus`` sized for a named dataset, optionally loading trained weights.

    Parameters
    ----------
    dataset : str
        Key into the project's ``datasets`` registry; the entry's ``NUM_CLASS``
        becomes the model's number of classes.
    backbone : str
        Forwarded to ``DeepLabV3Plus`` and used in the checkpoint name.
    pretrained : bool
        When true, full-model weights named
        ``deeplabv3_plus_<backbone>_<dataset acronym>`` are located through
        ``get_model_file`` and loaded into the model.
    root : str
        Passed to ``get_model_file`` as ``root``.
    pretrained_base : bool
        Forwarded to ``DeepLabV3Plus``.
    **kwargs
        Forwarded to ``DeepLabV3Plus``. If ``local_rank`` is present it also
        selects the device the checkpoint is mapped onto; otherwise the
        checkpoint is mapped onto the CPU.

    Returns
    -------
    DeepLabV3Plus
        The constructed (and possibly weight-loaded) model.

    Raises
    ------
    KeyError
        If ``dataset`` is not in the ``datasets`` registry, or ``pretrained`` is
        true and ``dataset`` has no acronym in the table below.
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if not pretrained:
        return model

    from .model_store import get_model_file
    # Previously kwargs['local_rank'] was indexed directly, so pretrained=True
    # raised a bare KeyError unless the caller happened to pass local_rank.
    # Fall back to the CPU, which is always a valid map_location.
    map_location = torch.device(kwargs.get('local_rank', 'cpu'))
    checkpoint_name = 'deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset])
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted store.
    state_dict = torch.load(get_model_file(checkpoint_name, root=root), map_location=map_location)
    model.load_state_dict(state_dict)
    return model
def get_deeplabv3_plus_xception_voc(**kwargs):
    """Build DeepLabV3+ for the 'pascal_voc' dataset with the 'xception' backbone.

    All keyword arguments are forwarded to ``get_deeplabv3_plus``.
    """
    return get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', **kwargs)
if __name__ == '__main__':
    # Script entry point: build the default Pascal VOC / Xception model once.
    model = get_deeplabv3_plus_xception_voc()