AIlib2/segutils/core/models/base_models/mobilenetv2.py

159 lines
5.1 KiB
Python

"""MobileNet and MobileNetV2."""
import torch
import torch.nn as nn
from core.nn import _ConvBNReLU, _DepthwiseConv, InvertedResidual
# Public API of this module: the two model classes, the two generic
# factories, and one convenience constructor per width multiplier.
__all__ = [
    'MobileNet',
    'MobileNetV2',
    'get_mobilenet',
    'get_mobilenet_v2',
    'mobilenet1_0',
    'mobilenet_v2_1_0',
    'mobilenet0_75',
    'mobilenet_v2_0_75',
    'mobilenet0_5',
    'mobilenet_v2_0_5',
    'mobilenet0_25',
    'mobilenet_v2_0_25',
]
class MobileNet(nn.Module):
    """MobileNet classifier: a conv stem followed by ``_DepthwiseConv`` blocks.

    Args:
        num_classes: Output size of the final linear layer.
        multiplier: Width multiplier applied to every stage's channel count.
            The stem keeps 32 channels unless ``multiplier`` is greater
            than 1.0, in which case it is scaled as well.
        norm_layer: Normalization layer class forwarded to the building
            blocks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        # Each entry: (base output channels, number of blocks, stride of the
        # first block in the stage). Later blocks in a stage use stride 1.
        stage_specs = (
            (64, 1, 1),
            (128, 2, 2),
            (256, 2, 2),
            (512, 6, 2),
            (1024, 2, 2),
        )

        # The stem is only widened, never narrowed, by the multiplier.
        stem_width = 32
        if multiplier > 1.0:
            stem_width = int(32 * multiplier)

        layers = [_ConvBNReLU(3, stem_width, 3, 2, 1, norm_layer=norm_layer)]
        in_width = stem_width
        for base_width, repeats, first_stride in stage_specs:
            stage_width = int(base_width * multiplier)
            for block_idx in range(repeats):
                block_stride = first_stride if block_idx == 0 else 1
                layers.append(
                    _DepthwiseConv(in_width, stage_width, block_stride, norm_layer))
                in_width = stage_width
        # Global average pooling collapses the spatial dims to 1x1.
        layers.append(nn.AdaptiveAvgPool2d(1))

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(int(1024 * multiplier), num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize conv, batch-norm and linear parameters in module order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the feature extractor, then classify the pooled (N, C) features."""
        pooled = self.features(x)
        batch_size = pooled.size(0)
        num_channels = pooled.size(1)
        # The pooled map is (N, C, 1, 1); drop the two unit spatial dims.
        flat = pooled.view(batch_size, num_channels)
        return self.classifier(flat)
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``InvertedResidual`` blocks.

    Args:
        num_classes: Output size of the final linear layer.
        multiplier: Width multiplier applied to each stage's output channels.
            The stem (32 channels) and the last 1x1 conv (1280 channels) are
            only scaled when ``multiplier`` is greater than 1.0.
        norm_layer: Normalization layer class forwarded to the building
            blocks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
        super(MobileNetV2, self).__init__()
        # Per stage: t is forwarded as the fourth argument of InvertedResidual
        # (presumably the expansion ratio -- confirm in core.nn), c is the
        # base output channel count, n the number of blocks, s the stride of
        # the first block (later blocks in the stage use stride 1).
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer; stem and head are widened but never narrowed
        input_channels = int(32 * multiplier) if multiplier > 1.0 else 32
        last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
        features = [_ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)]

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            out_channels = int(c * multiplier)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channels, out_channels, stride, t, norm_layer))
                input_channels = out_channels

        # building last several layers
        features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer))
        features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)

        # nn.Dropout, not nn.Dropout2d: forward() feeds the classifier a 2-D
        # (N, C) tensor, and Dropout2d on 2-D input is deprecated in PyTorch
        # (it warns and is slated to become an error). Dropout has no
        # parameters, so state_dict keys are unaffected.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channels, num_classes))

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the feature extractor, then classify the pooled (N, C) features."""
        x = self.features(x)
        # The pooled map is (N, C, 1, 1); reshape to (N, C) for the classifier.
        x = self.classifier(x.view(x.size(0), x.size(1)))
        return x
# Constructors: factory functions that build the models defined above.
def get_mobilenet(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
    """Build a ``MobileNet`` with the given width multiplier.

    Args:
        multiplier: Width multiplier forwarded to ``MobileNet``.
        pretrained: Must be falsy; pretrained weights are not supported.
        root: Unused; kept so existing callers can keep passing it.
        **kwargs: Forwarded to the ``MobileNet`` constructor.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option up front instead of building and
    # initializing a whole network only to throw it away.
    if pretrained:
        raise ValueError("Not support pretrained")
    return MobileNet(multiplier=multiplier, **kwargs)
def get_mobilenet_v2(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
    """Build a ``MobileNetV2`` with the given width multiplier.

    Args:
        multiplier: Width multiplier forwarded to ``MobileNetV2``.
        pretrained: Must be falsy; pretrained weights are not supported.
        root: Unused; kept so existing callers can keep passing it.
        **kwargs: Forwarded to the ``MobileNetV2`` constructor.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option up front instead of building and
    # initializing a whole network only to throw it away.
    if pretrained:
        raise ValueError("Not support pretrained")
    return MobileNetV2(multiplier=multiplier, **kwargs)
def mobilenet1_0(**kwargs):
    """Return ``get_mobilenet`` called with multiplier 1.0; kwargs are forwarded."""
    return get_mobilenet(multiplier=1.0, **kwargs)
def mobilenet_v2_1_0(**kwargs):
    """Return ``get_mobilenet_v2`` called with multiplier 1.0; kwargs are forwarded."""
    return get_mobilenet_v2(multiplier=1.0, **kwargs)
def mobilenet0_75(**kwargs):
    """Return ``get_mobilenet`` called with multiplier 0.75; kwargs are forwarded."""
    return get_mobilenet(multiplier=0.75, **kwargs)
def mobilenet_v2_0_75(**kwargs):
    """Return ``get_mobilenet_v2`` called with multiplier 0.75; kwargs are forwarded."""
    return get_mobilenet_v2(multiplier=0.75, **kwargs)
def mobilenet0_5(**kwargs):
    """Return ``get_mobilenet`` called with multiplier 0.5; kwargs are forwarded."""
    return get_mobilenet(multiplier=0.5, **kwargs)
def mobilenet_v2_0_5(**kwargs):
    """Return ``get_mobilenet_v2`` called with multiplier 0.5; kwargs are forwarded."""
    return get_mobilenet_v2(multiplier=0.5, **kwargs)
def mobilenet0_25(**kwargs):
    """Return ``get_mobilenet`` called with multiplier 0.25; kwargs are forwarded."""
    return get_mobilenet(multiplier=0.25, **kwargs)
def mobilenet_v2_0_25(**kwargs):
    """Return ``get_mobilenet_v2`` called with multiplier 0.25; kwargs are forwarded."""
    return get_mobilenet_v2(multiplier=0.25, **kwargs)
if __name__ == "__main__":
    # Smoke check when run as a script: build the 0.5-width MobileNet.
    model = mobilenet0_5()