- """Efficient Neural Network"""
- import torch
- import torch.nn as nn
-
- __all__ = ['ENet', 'get_enet', 'get_enet_citys']
-
-
class ENet(nn.Module):
    """Efficient Neural Network (ENet), Paszke et al., "ENet: A Deep Neural Network
    Architecture for Real-Time Semantic Segmentation", 2016."""

    def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=None, **kwargs):
        super(ENet, self).__init__()
        # initial block: 13 learned feature maps are concatenated with the 3 pooled
        # input channels, giving 16 channels at 1/2 resolution
        self.initial = InitialBlock(13, **kwargs)

        # stage 1: downsample to 1/4 resolution, 64 channels
        self.bottleneck1_0 = Bottleneck(16, 16, 64, downsampling=True, **kwargs)
        self.bottleneck1_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_2 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_3 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_4 = Bottleneck(64, 16, 64, **kwargs)

        # stage 2: downsample to 1/8 resolution, 128 channels, mixing plain,
        # dilated and asymmetric bottlenecks
        self.bottleneck2_0 = Bottleneck(64, 32, 128, downsampling=True, **kwargs)
        self.bottleneck2_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck2_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck2_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck2_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        # stage 3: same as stage 2 but without the downsampling bottleneck
        self.bottleneck3_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck3_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck3_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck3_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        # stage 4: upsample to 1/4 resolution, 64 channels
        self.bottleneck4_0 = UpsamplingBottleneck(128, 16, 64, **kwargs)
        self.bottleneck4_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck4_2 = Bottleneck(64, 16, 64, **kwargs)

        # stage 5: upsample to 1/2 resolution, 16 channels
        self.bottleneck5_0 = UpsamplingBottleneck(64, 4, 16, **kwargs)
        self.bottleneck5_1 = Bottleneck(16, 4, 16, **kwargs)

        # transposed convolution back to full resolution with nclass channels
        self.fullconv = nn.ConvTranspose2d(16, nclass, 2, 2, bias=False)

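        # NOTE: 'exclusive' is read by the surrounding training framework (not shown in
        # this file); presumably it marks the ENet-specific modules so their parameters
        # can be treated differently from a pretrained backbone. This is an assumption
        # based on the attribute name, since no backbone is actually used here.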
        self.__setattr__('exclusive', ['bottleneck1_0', 'bottleneck1_1', 'bottleneck1_2', 'bottleneck1_3',
                                       'bottleneck1_4', 'bottleneck2_0', 'bottleneck2_1', 'bottleneck2_2',
                                       'bottleneck2_3', 'bottleneck2_4', 'bottleneck2_5', 'bottleneck2_6',
                                       'bottleneck2_7', 'bottleneck2_8', 'bottleneck3_1', 'bottleneck3_2',
                                       'bottleneck3_3', 'bottleneck3_4', 'bottleneck3_5', 'bottleneck3_6',
                                       'bottleneck3_7', 'bottleneck3_8', 'bottleneck4_0', 'bottleneck4_1',
                                       'bottleneck4_2', 'bottleneck5_0', 'bottleneck5_1', 'fullconv'])

    def forward(self, x):
        # initial block: 1/2 resolution
        x = self.initial(x)

        # stage 1: 1/4 resolution; keep the pooling indices for unpooling in stage 5
        x, max_indices1 = self.bottleneck1_0(x)
        x = self.bottleneck1_1(x)
        x = self.bottleneck1_2(x)
        x = self.bottleneck1_3(x)
        x = self.bottleneck1_4(x)

        # stage 2: 1/8 resolution; keep the pooling indices for unpooling in stage 4
        x, max_indices2 = self.bottleneck2_0(x)
        x = self.bottleneck2_1(x)
        x = self.bottleneck2_2(x)
        x = self.bottleneck2_3(x)
        x = self.bottleneck2_4(x)
        x = self.bottleneck2_5(x)
        x = self.bottleneck2_6(x)
        x = self.bottleneck2_7(x)
        x = self.bottleneck2_8(x)

        # stage 3: 1/8 resolution
        x = self.bottleneck3_1(x)
        x = self.bottleneck3_2(x)
        x = self.bottleneck3_3(x)
        x = self.bottleneck3_4(x)
        x = self.bottleneck3_5(x)
        x = self.bottleneck3_6(x)
        x = self.bottleneck3_7(x)
        x = self.bottleneck3_8(x)

        # stage 4: upsample to 1/4 resolution using the stage-2 pooling indices
        x = self.bottleneck4_0(x, max_indices2)
        x = self.bottleneck4_1(x)
        x = self.bottleneck4_2(x)

        # stage 5: upsample to 1/2 resolution using the stage-1 pooling indices
        x = self.bottleneck5_0(x, max_indices1)
        x = self.bottleneck5_1(x)

        # full-resolution logits
        x = self.fullconv(x)
        return tuple([x])


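# The initial block below follows the ENet paper: a strided 3x3 convolution runs in
# parallel with a 2x2 max-pool of the raw input, and the two results are concatenated.
# That is why the BatchNorm layer is sized out_channels + 3: the pooled branch keeps
# the 3 input (RGB) channels.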
class InitialBlock(nn.Module):
    """ENet initial block"""

    def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(InitialBlock, self).__init__()
        self.conv = nn.Conv2d(3, out_channels, 3, 2, 1, bias=False)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.bn = norm_layer(out_channels + 3)
        self.act = nn.PReLU()

    def forward(self, x):
        x_conv = self.conv(x)
        x_pool = self.maxpool(x)
        x = torch.cat([x_conv, x_pool], dim=1)
        x = self.bn(x)
        x = self.act(x)
        return x


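# Each bottleneck is a residual unit: a 1x1 projection reduces the channel count, a main
# convolution (plain, dilated, asymmetric 5x1/1x5, or strided 2x2 when downsampling)
# operates at the reduced width, and a 1x1 expansion with dropout restores it. When
# downsampling, the identity path is max-pooled (returning indices for later unpooling)
# and projected to the new channel count before the residual addition.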
class Bottleneck(nn.Module):
    """Bottlenecks include regular, asymmetric, downsampling, dilated"""

    def __init__(self, in_channels, inter_channels, out_channels, dilation=1, asymmetric=False,
                 downsampling=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super(Bottleneck, self).__init__()
        self.downsampling = downsampling
        if downsampling:
            self.maxpool = nn.MaxPool2d(2, 2, return_indices=True)
            self.conv_down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                norm_layer(out_channels)
            )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU()
        )

        if downsampling:
            self.conv2 = nn.Sequential(
                nn.Conv2d(inter_channels, inter_channels, 2, stride=2, bias=False),
                norm_layer(inter_channels),
                nn.PReLU()
            )
        else:
            if asymmetric:
                self.conv2 = nn.Sequential(
                    nn.Conv2d(inter_channels, inter_channels, (5, 1), padding=(2, 0), bias=False),
                    nn.Conv2d(inter_channels, inter_channels, (1, 5), padding=(0, 2), bias=False),
                    norm_layer(inter_channels),
                    nn.PReLU()
                )
            else:
                self.conv2 = nn.Sequential(
                    nn.Conv2d(inter_channels, inter_channels, 3, dilation=dilation, padding=dilation, bias=False),
                    norm_layer(inter_channels),
                    nn.PReLU()
                )
        self.conv3 = nn.Sequential(
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1)
        )
        self.act = nn.PReLU()

    def forward(self, x):
        identity = x
        if self.downsampling:
            identity, max_indices = self.maxpool(identity)
            identity = self.conv_down(identity)

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.act(out + identity)

        if self.downsampling:
            return out, max_indices
        else:
            return out


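# The upsampling bottleneck mirrors the downsampling one: the main path projects the
# input with a 1x1 convolution and unpools it with the max-pooling indices saved by the
# matching downsampling bottleneck, while the extension path uses a transposed
# convolution; the two branches are summed and passed through PReLU.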
class UpsamplingBottleneck(nn.Module):
    """Upsampling bottleneck block"""

    def __init__(self, in_channels, inter_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(UpsamplingBottleneck, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels)
        )
        self.upsampling = nn.MaxUnpool2d(2)

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            nn.ConvTranspose2d(inter_channels, inter_channels, 2, 2, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1)
        )
        self.act = nn.PReLU()

    def forward(self, x, max_indices):
        out_up = self.conv(x)
        out_up = self.upsampling(out_up, max_indices)

        out_ext = self.block(x)
        out = self.act(out_up + out_ext)
        return out


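# Note: when pretrained=True, get_enet reads kwargs['local_rank'] to choose the device
# for torch.load, so callers loading pretrained weights must pass local_rank; how the
# surrounding framework supplies it is assumed here, not enforced by this file.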
def get_enet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from core.data.dataloader import datasets
    model = ENet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('enet_%s' % (acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_enet_citys(**kwargs):
    return get_enet('citys', '', **kwargs)


if __name__ == '__main__':
    img = torch.randn(1, 3, 512, 512)
    model = get_enet_citys()
    output = model(img)
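    print(output[0].shape)  # expected: torch.Size([1, NUM_CLASS, 512, 512])

    # Minimal sketch of using ENet without the repo's dataset registry: construct the
    # model with an explicit class count (19 is an assumption matching Cityscapes) and
    # run a forward pass. This bypasses get_enet and loads no pretrained weights.
    model = ENet(nclass=19)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 512, 512))
    print(out[0].shape)  # torch.Size([1, 19, 512, 512])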