159 lines
5.1 KiB
Python
159 lines
5.1 KiB
Python
"""MobileNet and MobileNetV2."""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from core.nn import _ConvBNReLU, _DepthwiseConv, InvertedResidual
|
|
|
|
# Public names of this module: the two model classes, the two generic
# constructors, then the fixed-width factory functions.
__all__ = [
    'MobileNet', 'MobileNetV2', 'get_mobilenet', 'get_mobilenet_v2',
    'mobilenet1_0', 'mobilenet_v2_1_0', 'mobilenet0_75', 'mobilenet_v2_0_75',
    'mobilenet0_5', 'mobilenet_v2_0_5', 'mobilenet0_25', 'mobilenet_v2_0_25',
]
|
|
|
|
|
|
class MobileNet(nn.Module):
    """MobileNet classifier assembled from ``_DepthwiseConv`` blocks.

    The feature extractor is a stem ``_ConvBNReLU`` layer, a stack of
    ``_DepthwiseConv`` blocks and a global average pool
    (``nn.AdaptiveAvgPool2d(1)``); a single ``nn.Linear`` produces the logits.

    Args:
        num_classes: Number of outputs of the linear classifier.
        multiplier: Width multiplier applied to every stage's channel count.
            The stem width is scaled only when ``multiplier > 1.0``; otherwise
            it stays at 32 channels.
        norm_layer: Normalization layer class handed to the building blocks.
        **kwargs: Accepted so shared call sites can pass extra options;
            ignored here.
    """

    def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
        super(MobileNet, self).__init__()
        # Each row: (base output channels, number of blocks, stride of first block).
        stage_specs = (
            (64, 1, 1),
            (128, 2, 2),
            (256, 2, 2),
            (512, 6, 2),
            (1024, 2, 2),
        )

        # The stem is never narrowed below 32 channels.
        if multiplier > 1.0:
            in_channels = int(32 * multiplier)
        else:
            in_channels = 32
        layers = [_ConvBNReLU(3, in_channels, 3, 2, 1, norm_layer=norm_layer)]

        for base_width, repeats, first_stride in stage_specs:
            width = int(base_width * multiplier)
            for block_idx in range(repeats):
                # Only the first block of a stage uses the stage's stride.
                block_stride = 1 if block_idx else first_stride
                layers.append(_DepthwiseConv(in_channels, width, block_stride, norm_layer))
                in_channels = width
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Linear(int(1024 * multiplier), num_classes)

        # Weight initialization, visiting modules in ``self.modules()`` order.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the feature extractor, drop the 1x1 spatial dims, and classify."""
        pooled = self.features(x)
        batch_size, channels = pooled.size(0), pooled.size(1)
        return self.classifier(pooled.view(batch_size, channels))
|
|
|
|
|
|
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``InvertedResidual`` blocks.

    The feature extractor is a stem ``_ConvBNReLU`` layer, a stack of
    ``InvertedResidual`` blocks, a final ``_ConvBNReLU`` layer and a global
    average pool (``nn.AdaptiveAvgPool2d(1)``). The classifier is dropout
    followed by a single ``nn.Linear``.

    Args:
        num_classes: Number of outputs of the linear classifier.
        multiplier: Width multiplier applied to every stage's channel count.
            The stem width (32) and the final feature width (1280) are scaled
            only when ``multiplier > 1.0``.
        norm_layer: Normalization layer class handed to the building blocks.
        **kwargs: Accepted so shared call sites can pass extra options;
            ignored here.
    """

    def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
        super(MobileNetV2, self).__init__()
        # Each row is (t, c, n, s): t is forwarded unchanged to InvertedResidual
        # (presumably the expansion factor - confirm in core.nn), c is the base
        # output channel count, n the number of blocks in the stage, and s the
        # stride used by the first block of the stage.
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer; stem and head widths are never narrowed.
        input_channels = int(32 * multiplier) if multiplier > 1.0 else 32
        last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
        features = [_ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)]

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            out_channels = int(c * multiplier)
            for i in range(n):
                # Only the first block of a stage uses the stage's stride.
                stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channels, out_channels, stride, t, norm_layer))
                input_channels = out_channels

        # building last several layers
        features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer))
        features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)

        # forward() flattens the pooled features to (N, C) before this head.
        # Dropout2d is meant for 3-D/4-D inputs and PyTorch deprecates feeding
        # it a 2-D tensor, so plain element-wise Dropout is used instead. It
        # holds no parameters, so state-dict keys are unchanged.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channels, num_classes))

        # weight initialization, visiting modules in ``self.modules()`` order
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the feature extractor, drop the 1x1 spatial dims, and classify."""
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), x.size(1)))
        return x
|
|
|
|
|
|
# Constructor
|
|
def get_mobilenet(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
    """Build a ``MobileNet`` with the given width multiplier.

    Args:
        multiplier: Width multiplier forwarded to ``MobileNet``.
        pretrained: Must be falsy; pretrained weights are not supported.
        root: Unused; kept so the signature stays compatible with callers.
        **kwargs: Extra keyword arguments forwarded to ``MobileNet``.

    Returns:
        The newly constructed ``MobileNet`` instance.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option up front instead of building and
    # initializing a whole network that would be thrown away.
    if pretrained:
        raise ValueError("Not support pretrained")
    return MobileNet(multiplier=multiplier, **kwargs)
|
|
|
|
|
|
def get_mobilenet_v2(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
    """Build a ``MobileNetV2`` with the given width multiplier.

    Args:
        multiplier: Width multiplier forwarded to ``MobileNetV2``.
        pretrained: Must be falsy; pretrained weights are not supported.
        root: Unused; kept so the signature stays compatible with callers.
        **kwargs: Extra keyword arguments forwarded to ``MobileNetV2``.

    Returns:
        The newly constructed ``MobileNetV2`` instance.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option up front instead of building and
    # initializing a whole network that would be thrown away.
    if pretrained:
        raise ValueError("Not support pretrained")
    return MobileNetV2(multiplier=multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet1_0(**kwargs):
    """Return ``get_mobilenet`` called with width multiplier 1.0 and ``kwargs``."""
    width_multiplier = 1.0
    return get_mobilenet(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet_v2_1_0(**kwargs):
    """Return ``get_mobilenet_v2`` called with width multiplier 1.0 and ``kwargs``."""
    width_multiplier = 1.0
    return get_mobilenet_v2(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet0_75(**kwargs):
    """Return ``get_mobilenet`` called with width multiplier 0.75 and ``kwargs``."""
    width_multiplier = 0.75
    return get_mobilenet(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet_v2_0_75(**kwargs):
    """Return ``get_mobilenet_v2`` called with width multiplier 0.75 and ``kwargs``."""
    width_multiplier = 0.75
    return get_mobilenet_v2(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet0_5(**kwargs):
    """Return ``get_mobilenet`` called with width multiplier 0.5 and ``kwargs``."""
    width_multiplier = 0.5
    return get_mobilenet(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet_v2_0_5(**kwargs):
    """Return ``get_mobilenet_v2`` called with width multiplier 0.5 and ``kwargs``."""
    width_multiplier = 0.5
    return get_mobilenet_v2(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet0_25(**kwargs):
    """Return ``get_mobilenet`` called with width multiplier 0.25 and ``kwargs``."""
    width_multiplier = 0.25
    return get_mobilenet(width_multiplier, **kwargs)
|
|
|
|
|
|
def mobilenet_v2_0_25(**kwargs):
    """Return ``get_mobilenet_v2`` called with width multiplier 0.25 and ``kwargs``."""
    width_multiplier = 0.25
    return get_mobilenet_v2(width_multiplier, **kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check when run as a script: build the 0.5-width MobileNet.
    model = mobilenet0_5()
|