244 lines
9.0 KiB
Python
244 lines
9.0 KiB
Python
"""Efficient Neural Network"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
__all__ = ['ENet', 'get_enet', 'get_enet_citys']
|
|
|
|
|
|
class ENet(nn.Module):
    """Efficient Neural Network (ENet) for semantic segmentation.

    Layout:
      * initial block (halves the resolution, 16 channels out);
      * stage 1: one downsampling bottleneck + 4 regular ones (64 channels);
      * stage 2: one downsampling bottleneck + 8 regular / dilated /
        asymmetric ones (128 channels);
      * stage 3: the same 8 non-downsampling bottlenecks as stage 2;
      * stages 4-5: upsampling bottlenecks that reuse the max-pooling indices
        recorded by the stage-2 and stage-1 downsampling bottlenecks;
      * a 2x2 stride-2 transposed convolution producing ``nclass`` maps.

    Args:
        nclass: number of output classes (channels of the final score map).
        backbone, aux, jpu, pretrained_base: accepted for interface
            compatibility with the other segmentation models; not used here.
        **kwargs: forwarded to every sub-block (e.g. ``norm_layer``).
    """

    def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=None, **kwargs):
        super(ENet, self).__init__()
        # 13 convolution channels + 3 max-pooled input channels = 16 channels.
        self.initial = InitialBlock(13, **kwargs)

        # stage 1
        self.bottleneck1_0 = Bottleneck(16, 16, 64, downsampling=True, **kwargs)
        self.bottleneck1_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_2 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_3 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_4 = Bottleneck(64, 16, 64, **kwargs)

        # stage 2
        self.bottleneck2_0 = Bottleneck(64, 32, 128, downsampling=True, **kwargs)
        self.bottleneck2_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck2_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck2_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck2_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        # stage 3: same configuration as bottleneck2_1 .. bottleneck2_8
        self.bottleneck3_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck3_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck3_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck3_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        # stage 4
        self.bottleneck4_0 = UpsamplingBottleneck(128, 16, 64, **kwargs)
        self.bottleneck4_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck4_2 = Bottleneck(64, 16, 64, **kwargs)

        # stage 5
        self.bottleneck5_0 = UpsamplingBottleneck(64, 4, 16, **kwargs)
        self.bottleneck5_1 = Bottleneck(16, 4, 16, **kwargs)

        # output head
        self.fullconv = nn.ConvTranspose2d(16, nclass, 2, 2, bias=False)

        # Names of every child module except 'initial', in registration order
        # (bottleneck1_0 ... bottleneck5_1, fullconv).
        # NOTE(review): 'initial' is deliberately absent, as in the original
        # list. If the training code builds its optimizer only from the
        # modules named here, the initial block would go untrained. Confirm
        # against the trainer.
        self.exclusive = [name for name, _ in self.named_children() if name != 'initial']

    def forward(self, x):
        """Return a 1-tuple holding the ``nclass``-channel score map for ``x``."""
        # init
        x = self.initial(x)

        # stage 1 (keep pooling indices for stage 5)
        x, max_indices1 = self.bottleneck1_0(x)
        x = self.bottleneck1_1(x)
        x = self.bottleneck1_2(x)
        x = self.bottleneck1_3(x)
        x = self.bottleneck1_4(x)

        # stage 2 (keep pooling indices for stage 4)
        x, max_indices2 = self.bottleneck2_0(x)
        x = self.bottleneck2_1(x)
        x = self.bottleneck2_2(x)
        x = self.bottleneck2_3(x)
        x = self.bottleneck2_4(x)
        x = self.bottleneck2_5(x)
        x = self.bottleneck2_6(x)
        x = self.bottleneck2_7(x)
        x = self.bottleneck2_8(x)

        # stage 3: all eight bottlenecks. bottleneck3_5 was previously skipped,
        # which left its parameters unused.
        x = self.bottleneck3_1(x)
        x = self.bottleneck3_2(x)
        x = self.bottleneck3_3(x)
        x = self.bottleneck3_4(x)
        x = self.bottleneck3_5(x)
        x = self.bottleneck3_6(x)
        x = self.bottleneck3_7(x)
        x = self.bottleneck3_8(x)

        # stage 4: undo stage-2 pooling
        x = self.bottleneck4_0(x, max_indices2)
        x = self.bottleneck4_1(x)
        x = self.bottleneck4_2(x)

        # stage 5: undo stage-1 pooling
        x = self.bottleneck5_0(x, max_indices1)
        x = self.bottleneck5_1(x)

        # out
        x = self.fullconv(x)
        return (x,)
|
|
|
|
|
|
class InitialBlock(nn.Module):
    """ENet initial block.

    Applies two parallel branches to the 3-channel input: a stride-2 3x3
    convolution producing ``out_channels`` maps, and a 2x2 max-pool of the
    raw input. The two results are concatenated along dim 1, normalized,
    and passed through PReLU. The output therefore has
    ``out_channels + 3`` channels.
    """

    def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(InitialBlock, self).__init__()
        merged_channels = out_channels + 3  # conv features + pooled RGB channels
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn = norm_layer(merged_channels)
        self.act = nn.PReLU()

    def forward(self, x):
        branches = (self.conv(x), self.maxpool(x))
        merged = torch.cat(branches, dim=1)
        return self.act(self.bn(merged))
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """ENet bottleneck: regular, dilated, asymmetric or downsampling.

    Main branch: 1x1 projection to ``inter_channels`` -> spatial convolution
    -> 1x1 expansion to ``out_channels`` followed by ``Dropout2d(0.1)``. The
    result is added to the shortcut and passed through PReLU.

    The spatial convolution is selected as follows (first match wins):
      * ``downsampling``: 2x2 conv with stride 2. The shortcut is a 2x2
        max-pool (returning indices) followed by a 1x1 conv + norm to
        ``out_channels``.
      * ``asymmetric``: a 5x1 conv followed by a 1x5 conv.
      * otherwise: a 3x3 conv with ``dilation`` and matching padding.

    Without downsampling the shortcut is the input itself, so
    ``in_channels`` must equal ``out_channels`` for the addition to work.

    ``forward`` returns ``out``, or ``(out, max_indices)`` when downsampling.
    """

    def __init__(self, in_channels, inter_channels, out_channels, dilation=1, asymmetric=False,
                 downsampling=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super(Bottleneck, self).__init__()
        # NOTE: the attribute name is misspelled ("downsamping"); it is kept
        # unchanged in this refactor.
        self.downsamping = downsampling
        if downsampling:
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
            self.conv_down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                norm_layer(out_channels),
            )

        # 1x1 projection
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
        )

        # spatial convolution(s); norm + PReLU are appended below
        if downsampling:
            spatial = [nn.Conv2d(inter_channels, inter_channels, kernel_size=2, stride=2, bias=False)]
        elif asymmetric:
            spatial = [
                nn.Conv2d(inter_channels, inter_channels, kernel_size=(5, 1), padding=(2, 0), bias=False),
                nn.Conv2d(inter_channels, inter_channels, kernel_size=(1, 5), padding=(0, 2), bias=False),
            ]
        else:
            spatial = [nn.Conv2d(inter_channels, inter_channels, kernel_size=3,
                                 dilation=dilation, padding=dilation, bias=False)]
        self.conv2 = nn.Sequential(*spatial, norm_layer(inter_channels), nn.PReLU())

        # 1x1 expansion with spatial dropout
        self.conv3 = nn.Sequential(
            nn.Conv2d(inter_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1),
        )
        self.act = nn.PReLU()

    def forward(self, x):
        if self.downsamping:
            shortcut, max_indices = self.maxpool(x)
            shortcut = self.conv_down(shortcut)
        else:
            shortcut = x

        main = self.conv3(self.conv2(self.conv1(x)))
        out = self.act(main + shortcut)
        return (out, max_indices) if self.downsamping else out
|
|
|
|
|
|
class UpsamplingBottleneck(nn.Module):
    """ENet upsampling bottleneck (decoder block).

    Shortcut: 1x1 conv + norm to ``out_channels``, then ``MaxUnpool2d(2)``
    driven by ``max_indices``. ``ENet`` passes the indices recorded by a
    downsampling ``Bottleneck``, and their channel count must match
    ``out_channels``.
    Main branch: 1x1 conv -> 2x2 stride-2 transposed conv -> 1x1 conv with
    ``Dropout2d(0.1)``. The two branches are summed and passed through PReLU.
    """

    def __init__(self, in_channels, inter_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(UpsamplingBottleneck, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
        )
        self.upsampling = nn.MaxUnpool2d(kernel_size=2)

        self.block = nn.Sequential(
            # reduce channels
            nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            # learnable x2 upsampling
            nn.ConvTranspose2d(inter_channels, inter_channels, kernel_size=2, stride=2, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            # expand channels + regularize
            nn.Conv2d(inter_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1),
        )
        self.act = nn.PReLU()

    def forward(self, x, max_indices):
        shortcut = self.upsampling(self.conv(x), max_indices)
        main = self.block(x)
        return self.act(shortcut + main)
|
|
|
|
|
|
def get_enet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs):
    """Build an ``ENet`` for ``dataset``, optionally loading pretrained weights.

    Args:
        dataset: key into ``core.data.dataloader.datasets``; its ``NUM_CLASS``
            sets the number of output classes. When ``pretrained`` is True it
            must also be a key of ``acronyms`` below.
        backbone: forwarded to ``ENet``, which does not use it.
        pretrained: if True, load the ``enet_<acronym>`` checkpoint located
            through ``get_model_file``.
        root: directory passed to ``get_model_file``.
        pretrained_base: forwarded to ``ENet``, which does not use it.
        **kwargs: forwarded to ``ENet``. An optional ``local_rank`` selects
            the device the checkpoint is mapped to; without it (or when it is
            None) the checkpoint is mapped to CPU.

    Returns:
        The constructed ``ENet`` module.
    """
    # dataset name -> suffix used in the pretrained checkpoint file name
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from core.data.dataloader import datasets
    model = ENet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Previously kwargs['local_rank'] raised KeyError when the caller did
        # not run under a distributed launcher; fall back to CPU instead.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu') if local_rank is None else torch.device(local_rank)
        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(get_model_file('enet_%s' % acronyms[dataset], root=root),
                                         map_location=device))
    return model
|
|
|
|
|
|
def get_enet_citys(**kwargs):
    """Shortcut for ``get_enet`` on the ``'citys'`` dataset (presumably Cityscapes).

    Uses an empty ``backbone``; every keyword argument is forwarded to
    ``get_enet`` unchanged.
    """
    return get_enet(dataset='citys', backbone='', **kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: push one random 512x512 RGB image through the 'citys' ENet
    # and report the score-map shape. Previously the output was computed and
    # discarded. Eval mode + no_grad avoid building an autograd graph and
    # updating BatchNorm running statistics for a throwaway input.
    img = torch.randn(1, 3, 512, 512)
    model = get_enet_citys()
    model.eval()
    with torch.no_grad():
        output = model(img)
    print(output[0].shape)
|