265 lines
9.5 KiB
Python
265 lines
9.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.model_zoo as model_zoo
|
|
|
|
# Public API of this module: the network class followed by its factory
# functions (the ``v1s`` variants build the network with ``deep_stem=True``).
__all__ = [
    'ResNetV1b',
    'resnet18_v1b',
    'resnet34_v1b',
    'resnet50_v1b',
    'resnet101_v1b',
    'resnet152_v1b',
    'resnet152_v1s',
    'resnet101_v1s',
    'resnet50_v1s',
]
|
|
|
|
# Checkpoint URLs used by the ``*_v1b`` factory functions when
# ``pretrained=True``; keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
|
|
|
|
|
|
class BasicBlockV1b(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by a norm layer.

    The shortcut (the input itself, or ``downsample(input)`` when a
    ``downsample`` module is given) is added to the second norm layer's output
    before the final ReLU.

    Args:
        inplanes: Number of input channels of the first convolution.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        dilation: Dilation (and padding) of the first convolution.
        downsample: Optional module applied to the input to build the shortcut.
        previous_dilation: Dilation (and padding) of the second convolution.
        norm_layer: Callable mapping a channel count to a normalisation module.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super(BasicBlockV1b, self).__init__()
        # padding == dilation keeps a 3x3 convolution size-preserving at stride 1.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=previous_dilation,
                               dilation=previous_dilation, bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
|
|
|
|
|
class BottleneckV1b(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with norm layers.

    The last 1x1 convolution expands the channel count to
    ``planes * expansion``. The shortcut (the input itself, or
    ``downsample(input)`` when a ``downsample`` module is given) is added
    before the final ReLU.

    Args:
        inplanes: Number of input channels of the first 1x1 convolution.
        planes: Channel count of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: Optional module applied to the input to build the shortcut.
        previous_dilation: Accepted for signature parity with
            ``BasicBlockV1b``; this block does not use it.
        norm_layer: Callable mapping a channel count to a normalisation module.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super(BottleneckV1b, self).__init__()
        out_channels = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(planes)
        # padding == dilation keeps the 3x3 convolution size-preserving at stride 1.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/norm stages, add the shortcut and apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
|
|
|
|
|
class ResNetV1b(nn.Module):
    """ResNet classifier assembled from a configurable residual block.

    The network is a stem (``conv1``/``bn1``/``relu``/``maxpool``), four
    residual stages (``layer1`` .. ``layer4``), global average pooling and a
    linear classifier ``fc``.

    Args:
        block: Residual block factory. It must expose an ``expansion``
            attribute and accept ``(inplanes, planes, stride, dilation=...,
            downsample=..., previous_dilation=..., norm_layer=...)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Output size of the final linear layer.
        dilated: If true, ``layer3`` and ``layer4`` keep stride 1 and use
            dilation 2 and 4; otherwise both use stride 2 without dilation.
        deep_stem: If true, the 7x7 stem convolution is replaced by three 3x3
            convolutions ending in 128 channels.
        zero_init_residual: If true, zero the weight of the last norm layer of
            every residual block (``bn3`` for ``BottleneckV1b``, ``bn2`` for
            ``BasicBlockV1b``).
        norm_layer: Callable mapping a channel count to a normalisation module.
    """

    def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False,
                 zero_init_residual=False, norm_layer=nn.BatchNorm2d):
        super(ResNetV1b, self).__init__()
        # Channel count entering the next stage; _make_layer keeps it current.
        self.inplanes = 128 if deep_stem else 64

        if deep_stem:
            stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
                norm_layer(64),
                nn.ReLU(True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(64),
                nn.ReLU(True),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            )
        else:
            stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1 = stem
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Dilated mode trades the stride of the last two stages for dilation.
        if dilated:
            stage3_cfg = {'stride': 1, 'dilation': 2}
            stage4_cfg = {'stride': 1, 'dilation': 4}
        else:
            stage3_cfg = {'stride': 2}
            stage4_cfg = {'stride': 2}

        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], norm_layer=norm_layer, **stage3_cfg)
        self.layer4 = self._make_layer(block, 512, layers[3], norm_layer=norm_layer, **stage4_cfg)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation for every convolution and BatchNorm2d module.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            # Zero the scale of the last norm layer in each residual block.
            for module in self.modules():
                if isinstance(module, BottleneckV1b):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlockV1b):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
        """Build one residual stage as an ``nn.Sequential`` of ``blocks`` blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + norm ``downsample`` shortcut.
        Its own ``dilation`` is 1 for a stage dilation of 1 or 2 and 2 for a
        stage dilation of 4; the remaining blocks use the stage dilation.
        ``self.inplanes`` is updated to the stage's output channel count.

        Raises:
            RuntimeError: If ``dilation`` is not 1, 2 or 4.
        """
        out_channels = planes * block.expansion

        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                norm_layer(out_channels),
            )

        if dilation not in (1, 2, 4):
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))
        first_dilation = 2 if dilation == 4 else 1

        stage = [block(self.inplanes, planes, stride, dilation=first_dilation,
                       downsample=shortcut, previous_dilation=dilation,
                       norm_layer=norm_layer)]
        self.inplanes = out_channels
        stage.extend(
            block(self.inplanes, planes, dilation=dilation,
                  previous_dilation=dilation, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, average pooling and the classifier."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        # Collapse everything but the batch dimension before the linear layer.
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
|
|
|
|
|
def resnet18_v1b(pretrained=False, **kwargs):
    """Build a ResNetV1b from ``BasicBlockV1b`` with stage depths [2, 2, 2, 2].

    Args:
        pretrained: If true, download the checkpoint at
            ``model_urls['resnet18']`` and copy over every tensor whose name
            and shape match the model's.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet18'])
        model_dict = model.state_dict()
        # Matching on name alone makes load_state_dict raise on ``fc.*`` when
        # num_classes differs from the checkpoint's, so the shape must agree
        # too; mismatched tensors keep their fresh initialisation.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
def resnet34_v1b(pretrained=False, **kwargs):
    """Build a ResNetV1b from ``BasicBlockV1b`` with stage depths [3, 4, 6, 3].

    Args:
        pretrained: If true, download the checkpoint at
            ``model_urls['resnet34']`` and copy over every tensor whose name
            and shape match the model's.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet34'])
        model_dict = model.state_dict()
        # Matching on name alone makes load_state_dict raise on ``fc.*`` when
        # num_classes differs from the checkpoint's, so the shape must agree
        # too; mismatched tensors keep their fresh initialisation.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
def resnet50_v1b(pretrained=False, **kwargs):
    """Build a ResNetV1b from ``BottleneckV1b`` with stage depths [3, 4, 6, 3].

    Args:
        pretrained: If true, download the checkpoint at
            ``model_urls['resnet50']`` and copy over every tensor whose name
            and shape match the model's.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet50'])
        model_dict = model.state_dict()
        # Matching on name alone makes load_state_dict raise on ``fc.*`` when
        # num_classes differs from the checkpoint's, so the shape must agree
        # too; mismatched tensors keep their fresh initialisation.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
def resnet101_v1b(pretrained=False, **kwargs):
    """Build a ResNetV1b from ``BottleneckV1b`` with stage depths [3, 4, 23, 3].

    Args:
        pretrained: If true, download the checkpoint at
            ``model_urls['resnet101']`` and copy over every tensor whose name
            and shape match the model's.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet101'])
        model_dict = model.state_dict()
        # Matching on name alone makes load_state_dict raise on ``fc.*`` when
        # num_classes differs from the checkpoint's, so the shape must agree
        # too; mismatched tensors keep their fresh initialisation.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
def resnet152_v1b(pretrained=False, **kwargs):
    """Build a ResNetV1b from ``BottleneckV1b`` with stage depths [3, 8, 36, 3].

    Args:
        pretrained: If true, download the checkpoint at
            ``model_urls['resnet152']`` and copy over every tensor whose name
            and shape match the model's.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet152'])
        model_dict = model.state_dict()
        # Matching on name alone makes load_state_dict raise on ``fc.*`` when
        # num_classes differs from the checkpoint's, so the shape must agree
        # too; mismatched tensors keep their fresh initialisation.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
def resnet50_v1s(pretrained=False, root='~/.torch/models', **kwargs):
    """Build a deep-stem ResNetV1b from ``BottleneckV1b`` with depths [3, 4, 6, 3].

    Args:
        pretrained: If true, load the checkpoint file returned by
            ``get_resnet_file('resnet50', root=root)`` and copy over every
            tensor whose name and shape match the model's.
        root: Directory passed to ``get_resnet_file``.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs)
    if pretrained:
        from ..model_store import get_resnet_file
        # Load onto CPU so a checkpoint saved from a GPU still loads on a
        # CPU-only host; load_state_dict copies values into the model's tensors.
        checkpoint = torch.load(get_resnet_file('resnet50', root=root), map_location='cpu')
        model_dict = model.state_dict()
        # strict=False tolerates missing/unexpected names but still raises on a
        # shape mismatch (e.g. ``fc.*`` with a custom num_classes), so drop
        # tensors whose shape differs from the model's.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model.load_state_dict(compatible, strict=False)
    return model
|
|
|
|
|
|
def resnet101_v1s(pretrained=False, root='~/.torch/models', **kwargs):
    """Build a deep-stem ResNetV1b from ``BottleneckV1b`` with depths [3, 4, 23, 3].

    Args:
        pretrained: If true, load the checkpoint file returned by
            ``get_resnet_file('resnet101', root=root)`` and copy over every
            tensor whose name and shape match the model's.
        root: Directory passed to ``get_resnet_file``.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs)
    if pretrained:
        from ..model_store import get_resnet_file
        # Load onto CPU so a checkpoint saved from a GPU still loads on a
        # CPU-only host; load_state_dict copies values into the model's tensors.
        checkpoint = torch.load(get_resnet_file('resnet101', root=root), map_location='cpu')
        model_dict = model.state_dict()
        # strict=False tolerates missing/unexpected names but still raises on a
        # shape mismatch (e.g. ``fc.*`` with a custom num_classes), so drop
        # tensors whose shape differs from the model's.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model.load_state_dict(compatible, strict=False)
    return model
|
|
|
|
|
|
def resnet152_v1s(pretrained=False, root='~/.torch/models', **kwargs):
    """Build a deep-stem ResNetV1b from ``BottleneckV1b`` with depths [3, 8, 36, 3].

    Args:
        pretrained: If true, load the checkpoint file returned by
            ``get_resnet_file('resnet152', root=root)`` and copy over every
            tensor whose name and shape match the model's.
        root: Directory passed to ``get_resnet_file``.
        **kwargs: Forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The constructed ``ResNetV1b``.
    """
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs)
    if pretrained:
        from ..model_store import get_resnet_file
        # Load onto CPU so a checkpoint saved from a GPU still loads on a
        # CPU-only host; load_state_dict copies values into the model's tensors.
        checkpoint = torch.load(get_resnet_file('resnet152', root=root), map_location='cpu')
        model_dict = model.state_dict()
        # strict=False tolerates missing/unexpected names but still raises on a
        # shape mismatch (e.g. ``fc.*`` with a custom num_classes), so drop
        # tensors whose shape differs from the model's.
        compatible = {k: v for k, v in checkpoint.items()
                      if k in model_dict and v.shape == model_dict[k].shape}
        model.load_state_dict(compatible, strict=False)
    return model
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build resnet50_v1b with pretrained weights (this
    # fetches the checkpoint through model_zoo) and run one forward pass on a
    # random batch of four 3x224x224 inputs.
    batch = torch.randn(4, 3, 224, 224)
    net = resnet50_v1b(pretrained=True)
    result = net(batch)
|