# AIlib2/segutils/core/models/base_models/resnetv1b.py
#
# Listing metadata: 265 lines, 9.5 KiB, Python.

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# Public API: the backbone class and its factory functions.
__all__ = [
    'ResNetV1b',
    'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 'resnet101_v1b', 'resnet152_v1b',
    'resnet152_v1s', 'resnet101_v1s', 'resnet50_v1s',
]
# Download URLs of the checkpoints loaded by the ``resnet*_v1b`` factory
# functions when ``pretrained=True``; keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
class BasicBlockV1b(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    Output channels are ``planes * expansion`` (== ``planes``).

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution.
        dilation: dilation (and padding) of the first 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
            When it is None the input itself is added, so the caller must supply
            one whenever the main branch changes the tensor shape.
        previous_dilation: dilation (and padding) of the second 3x3 convolution.
        norm_layer: callable ``norm_layer(channels)`` returning a norm module.
    """

    # Channel multiplier between ``planes`` and the block's output channels.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Submodules are created in a fixed order (conv1, bn1, relu, conv2, bn2)
        # which determines parameter registration / state_dict key order.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=previous_dilation,
                               dilation=previous_dilation, bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        # Stored for reference only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        # The main branch never modifies ``x`` in place, so the shortcut can be
        # computed up front.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class BottleneckV1b(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    Stride and dilation are applied by the middle 3x3 convolution, and the block
    outputs ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: width of the first two convolutions (before expansion).
        stride: stride of the 3x3 convolution.
        dilation: dilation (and padding) of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut;
            needed whenever the main branch changes the tensor shape.
        previous_dilation: unused by this block; it mirrors BasicBlockV1b's
            signature because ``ResNetV1b._make_layer`` always passes it.
        norm_layer: callable ``norm_layer(channels)`` returning a norm module.
    """

    # Channel multiplier applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        expanded = planes * self.expansion
        # Creation order (conv1, bn1, conv2, bn2, conv3, bn3, relu) fixes the
        # parameter registration / state_dict key order.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = norm_layer(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Stored for reference only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        # ``x`` is never modified in place by the main branch, so the shortcut
        # may be computed first.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
class ResNetV1b(nn.Module):
    """ResNet classifier/backbone with optional deep stem and dilated stages.

    Args:
        block: residual block class; must expose ``expansion`` and accept the
            keyword arguments passed by ``_make_layer``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``fc`` layer.
        dilated: if True, layer3/layer4 use stride 1 with dilation 2/4 instead of
            stride 2, so the overall stride is 8 rather than 32.
        deep_stem: replace the single 7x7 stem convolution with three 3x3
            convolutions ending at 128 channels.
        zero_init_residual: zero the weight of the last norm layer of every
            residual block (``bn3`` for BottleneckV1b, ``bn2`` for BasicBlockV1b).
        norm_layer: callable ``norm_layer(channels)`` returning a norm module.
    """

    def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False,
                 zero_init_residual=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Running channel count fed to the next stage; _make_layer updates it.
        self.inplanes = 128 if deep_stem else 64

        # NOTE: submodules are built in the same order as always (stem, bn1,
        # relu, maxpool, layer1..4, avgpool, fc); this order fixes both the
        # state_dict key order and the RNG draws of default initialisation.
        if deep_stem:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
                norm_layer(64),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(64),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            )
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        # (stride, dilation) for layer3 and layer4.
        if dilated:
            (s3, d3), (s4, d4) = (1, 2), (1, 4)
        else:
            (s3, d3), (s4, d4) = (2, 1), (2, 1)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=s3, dilation=d3,
                                       norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=s4, dilation=d4,
                                       norm_layer=norm_layer)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, BottleneckV1b):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlockV1b):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block receives ``stride`` and, when the stride or channel
        count changes, a 1x1 conv + norm projection shortcut. The first block's
        own dilation is 1 for a stage dilation of 1 or 2, and 2 for a stage
        dilation of 4; the remaining blocks use the stage dilation.

        Raises:
            RuntimeError: if ``dilation`` is not 1, 2 or 4.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, 1, stride, bias=False),
                norm_layer(out_channels),
            )

        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        stage = [block(self.inplanes, planes, stride, dilation=first_dilation,
                       downsample=downsample, previous_dilation=dilation,
                       norm_layer=norm_layer)]
        self.inplanes = out_channels
        stage.extend(
            block(self.inplanes, planes, dilation=dilation,
                  previous_dilation=dilation, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(x.view(x.size(0), -1))
def resnet18_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
if pretrained:
old_dict = model_zoo.load_url(model_urls['resnet18'])
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
def resnet34_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
if pretrained:
old_dict = model_zoo.load_url(model_urls['resnet34'])
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
def resnet50_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs)
if pretrained:
old_dict = model_zoo.load_url(model_urls['resnet50'])
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
def resnet101_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs)
if pretrained:
old_dict = model_zoo.load_url(model_urls['resnet101'])
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
def resnet152_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs)
if pretrained:
old_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
def resnet50_v1s(pretrained=False, root='~/.torch/models', **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs)
if pretrained:
from ..model_store import get_resnet_file
model.load_state_dict(torch.load(get_resnet_file('resnet50', root=root)), strict=False)
return model
def resnet101_v1s(pretrained=False, root='~/.torch/models', **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs)
if pretrained:
from ..model_store import get_resnet_file
model.load_state_dict(torch.load(get_resnet_file('resnet101', root=root)), strict=False)
return model
def resnet152_v1s(pretrained=False, root='~/.torch/models', **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs)
if pretrained:
from ..model_store import get_resnet_file
model.load_state_dict(torch.load(get_resnet_file('resnet152', root=root)), strict=False)
return model
if __name__ == '__main__':
    # Smoke test: downloads the torchvision ResNet-50 checkpoint, runs one
    # random batch in inference mode and reports the output shape.
    # (``torch`` is already imported at module level.)
    img = torch.randn(4, 3, 224, 224)
    model = resnet50_v1b(True).eval()
    with torch.no_grad():
        output = model(img)
    print(output.shape)