# AIlib2/segutils/core/models/base_models/xception.py
# (412 lines, 14 KiB, Python)

import torch
import torch.nn as nn
import torch.nn.functional as F
# Names exported by a star-import of this module.
# NOTE(review): the XceptionA class is not listed although its factory
# get_xception_a is -- confirm whether that omission is intentional.
__all__ = ['Enc', 'FCAttention', 'Xception65', 'Xception71', 'get_xception', 'get_xception_71', 'get_xception_a']
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution with explicit input padding.

    ``forward`` zero-pads the input itself (the depthwise convolution is
    created with ``padding=0``), then applies a depthwise convolution, a
    normalization layer and a 1x1 pointwise convolution. The padding is chosen
    so that a stride-1 convolution keeps the spatial size of its input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Side length of the square depthwise kernel.
        stride: Stride of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
        norm_layer: Callable that builds the normalization layer from a
            channel count. ``None`` (the default) selects ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
        super(SeparableConv2d, self).__init__()
        # The default used to be called directly, which raised
        # "TypeError: 'NoneType' object is not callable".
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.kernel_size = kernel_size
        self.dilation = dilation

        # groups=in_channels makes this a depthwise convolution; padding is 0
        # because fix_padding() pads the input explicitly.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, 0, dilation, groups=in_channels,
                               bias=bias)
        self.bn = norm_layer(in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Pad ``x``, then apply depthwise conv -> norm -> pointwise conv."""
        x = self.fix_padding(x, self.kernel_size, self.dilation)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x

    def fix_padding(self, x, kernel_size, dilation):
        """Zero-pad the last two dimensions of ``x`` for the given kernel.

        The total padding per dimension is the dilated kernel extent minus
        one. When that total is odd, the extra element goes to the end
        (right / bottom) side.
        """
        kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(x, (pad_beg, pad_end, pad_beg, pad_end))
class Block(nn.Module):
    """Residual unit whose main branch is a stack of separable convolutions.

    The main branch is a sequence of (ReLU, SeparableConv2d, norm) groups. The
    shortcut is the identity, or a 1x1 convolution (with the block's stride)
    followed by a norm layer when the channel count or the stride changes.
    ``forward`` returns the sum of both branches.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        reps: ``reps - 1`` same-width separable convolutions are stacked in
            addition to the single channel-changing one.
        stride: When not 1, a strided separable convolution is appended to
            the main branch and the shortcut is strided as well.
        dilation: Dilation passed to the stride-1 separable convolutions.
        norm_layer: Callable that builds a normalization layer from a channel
            count. ``None`` (the default) selects ``nn.BatchNorm2d``.
        start_with_relu: With ``grow_first`` set, controls whether a ReLU
            precedes the first convolution. With ``grow_first`` unset,
            controls whether a ReLU precedes each same-width convolution.
        grow_first: If true the channel-changing convolution comes first and
            the same-width ones run at ``out_channels``; otherwise they run at
            ``in_channels`` and the channel change comes after them.
        is_last: When ``stride`` is 1, append one more stride-1 separable
            convolution group at the end.
    """

    def __init__(self, in_channels, out_channels, reps, stride=1, dilation=1, norm_layer=None,
                 start_with_relu=True, grow_first=True, is_last=False):
        super(Block, self).__init__()
        # The default used to be called directly, which raised
        # "TypeError: 'NoneType' object is not callable".
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None

        # In-place ReLU, shared by every position in the main branch.
        self.relu = nn.ReLU(True)
        rep = list()
        filters = in_channels
        if grow_first:
            if start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
            filters = out_channels
        for _ in range(reps - 1):
            if grow_first or start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(filters))
        if not grow_first:
            rep.append(self.relu)
            # No norm layer follows this convolution. Adding one would shift
            # the indices (and state_dict keys) of the later Sequential entries.
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
        if stride != 1:
            rep.append(self.relu)
            # The strided convolution is built without the dilation argument,
            # so it uses SeparableConv2d's default dilation.
            rep.append(SeparableConv2d(out_channels, out_channels, 3, stride, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        elif is_last:
            rep.append(self.relu)
            rep.append(SeparableConv2d(out_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        """Return ``rep(x) + shortcut(x)``."""
        # NOTE(review): when the main branch starts with the in-place ReLU,
        # `x` has already been rectified in place by the time the shortcut
        # reads it below. Presumably existing checkpoints rely on this --
        # confirm before changing it.
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        out = out + skip
        return out
class Xception65(nn.Module):
    """Modified Aligned Xception with a classification head.

    The network consists of an entry flow (two plain convolutions and three
    ``Block`` units), a middle flow of 16 identical ``Block`` units, and an
    exit flow (one ``Block`` plus three separable convolutions), followed by
    global average pooling and a linear classifier.

    Args:
        num_classes: Size of the final linear layer's output.
        output_stride: One of 32, 16 or 8. Lower values replace strides in
            ``block3`` / ``block20`` by 1 and raise the dilation of the
            middle and exit flows instead.
        norm_layer: Callable that builds a normalization layer from a channel
            count.

    Raises:
        NotImplementedError: If ``output_stride`` is not 32, 16 or 8.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception65, self).__init__()
        if output_stride == 32:
            entry_block3_stride = 2
            exit_block20_stride = 2
            middle_block_dilation = 1
            exit_block_dilations = (1, 1)
        elif output_stride == 16:
            entry_block3_stride = 2
            exit_block20_stride = 1
            middle_block_dilation = 1
            exit_block_dilations = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            exit_block20_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            # Same exception type as before, now naming the offending value.
            raise NotImplementedError(
                'output_stride must be one of 32, 16 or 8, got {!r}'.format(output_stride))

        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)  # in-place, reused throughout forward()

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)

        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: 16 identical residual blocks at 728 channels.
        midflow = list()
        for _ in range(4, 20):
            midflow.append(Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, norm_layer=norm_layer,
                                 start_with_relu=True, grow_first=True))
        self.midflow = nn.Sequential(*midflow)

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
                             norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the full network and return the output of the linear classifier."""
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        # block2 is built with start_with_relu=False, so the activation
        # between block1 and block2 is applied here.
        x = self.relu(x)
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        x = self.midflow(x)

        # Exit flow
        x = self.block20(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        # Global pooling leaves a 1x1 map per channel; flatten it for the classifier.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class Xception71(nn.Module):
    """Modified Aligned Xception with a deeper entry flow and a classification head.

    Same layout as the 65-layer model except that ``block2`` is a sequence of
    two ``Block`` units (128 -> 256 -> 728 channels) and ``block3`` maps 728
    to 728 channels.

    NOTE(review): both units inside ``block2`` use stride 2, so this entry
    flow contains one more stride-2 stage than the nominal ``output_stride``
    accounts for (e.g. ``output_stride=32`` downsamples by 64 in total).
    Confirm whether that is intended before relying on the feature-map size.

    Args:
        num_classes: Size of the final linear layer's output.
        output_stride: One of 32, 16 or 8; selects the strides of ``block3`` /
            ``block20`` and the dilations of the middle and exit flows.
        norm_layer: Callable that builds a normalization layer from a channel
            count.

    Raises:
        NotImplementedError: If ``output_stride`` is not 32, 16 or 8.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception71, self).__init__()
        if output_stride == 32:
            entry_block3_stride = 2
            exit_block20_stride = 2
            middle_block_dilation = 1
            exit_block_dilations = (1, 1)
        elif output_stride == 16:
            entry_block3_stride = 2
            exit_block20_stride = 1
            middle_block_dilation = 1
            exit_block_dilations = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            exit_block20_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            # Same exception type as before, now naming the offending value.
            raise NotImplementedError(
                'output_stride must be one of 32, 16 or 8, got {!r}'.format(output_stride))

        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)  # in-place, reused throughout forward()

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)

        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        self.block2 = nn.Sequential(
            Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True),
            Block(256, 728, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True))
        self.block3 = Block(728, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: 16 identical residual blocks at 728 channels.
        midflow = list()
        for _ in range(4, 20):
            midflow.append(Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, norm_layer=norm_layer,
                                 start_with_relu=True, grow_first=True))
        self.midflow = nn.Sequential(*midflow)

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
                             norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the full network and return the output of the linear classifier."""
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        # The first unit of block2 is built with start_with_relu=False, so the
        # activation between block1 and block2 is applied here.
        x = self.relu(x)
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        x = self.midflow(x)

        # Exit flow
        x = self.block20(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        # Global pooling leaves a 1x1 map per channel; flatten it for the classifier.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# -------------------------------------------------
# For DFANet
# -------------------------------------------------
class BlockA(nn.Module):
    """Bottleneck residual unit made of three separable convolutions.

    The main branch narrows to ``out_channels // 4`` channels for the first
    two separable convolutions and widens to ``out_channels`` in the third,
    which also carries the block's stride. The shortcut is the identity, or a
    strided 1x1 convolution followed by a norm layer when the channel count or
    the stride changes.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the last separable convolution and of the shortcut.
        dilation: Dilation of the first two separable convolutions (the last
            one is built without it and uses SeparableConv2d's default).
        norm_layer: Callable that builds a normalization layer from a channel
            count. ``None`` (the default) selects ``nn.BatchNorm2d``.
        start_with_relu: Whether a ReLU precedes the first convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1, norm_layer=None, start_with_relu=True):
        super(BlockA, self).__init__()
        # The default used to be called directly, which raised
        # "TypeError: 'NoneType' object is not callable".
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None

        # Not in-place, so the shortcut in forward() sees the untouched input.
        self.relu = nn.ReLU(False)
        rep = list()
        inter_channels = out_channels // 4

        if start_with_relu:
            rep.append(self.relu)
        rep.append(SeparableConv2d(in_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer))
        rep.append(norm_layer(inter_channels))

        rep.append(self.relu)
        rep.append(SeparableConv2d(inter_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer))
        rep.append(norm_layer(inter_channels))

        # The strided and unstrided cases previously had two identical
        # branches (the else-branch hard-coded stride 1); passing `stride`
        # covers both.
        rep.append(self.relu)
        rep.append(SeparableConv2d(inter_channels, out_channels, 3, stride, norm_layer=norm_layer))
        rep.append(norm_layer(out_channels))
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        """Return ``rep(x) + shortcut(x)``."""
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        out = out + skip
        return out
class Enc(nn.Module):
    """Encoder stage: one stride-2 ``BlockA`` followed by stride-1 ``BlockA`` units.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every unit of the stage.
        blocks: Total number of ``BlockA`` units in the stage.
        norm_layer: Callable that builds a normalization layer from a channel
            count; forwarded to every unit.
    """

    def __init__(self, in_channels, out_channels, blocks, norm_layer=nn.BatchNorm2d):
        super(Enc, self).__init__()
        # The first unit changes the channel count and downsamples.
        stages = [BlockA(in_channels, out_channels, 2, norm_layer=norm_layer)]
        # The remaining units keep both resolution and width.
        stages.extend(
            BlockA(out_channels, out_channels, 1, norm_layer=norm_layer)
            for _ in range(blocks - 1)
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all units in order."""
        return self.block(x)
class FCAttention(nn.Module):
    """Channel attention computed through a 1000-wide fully connected layer.

    The input is globally average-pooled, projected to 1000 features by a
    linear layer, mapped back to ``in_channels`` by a 1x1 convolution with
    norm and ReLU, and the result scales the input channel-wise.

    Args:
        in_channels: Channels of the input (and output) tensor.
        norm_layer: Callable that builds a normalization layer from a channel
            count.
    """

    def __init__(self, in_channels, norm_layer=nn.BatchNorm2d):
        super(FCAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, 1000)
        projection = [
            nn.Conv2d(1000, in_channels, 1, bias=False),
            norm_layer(in_channels),
            nn.ReLU(False),
        ]
        self.conv = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention weights."""
        batch, channels, _, _ = x.size()
        pooled = self.avgpool(x).view(batch, channels)
        # Reshape the linear output to a 1x1 map so the convolution can consume it.
        features = self.fc(pooled).view(batch, 1000, 1, 1)
        weights = self.conv(features)
        return x * weights.expand_as(x)
class XceptionA(nn.Module):
    """Small Xception-style classifier: stem, three ``Enc`` stages, attention, head.

    Args:
        num_classes: Size of the final linear layer's output.
        norm_layer: Callable that builds a normalization layer from a channel
            count; forwarded to every sub-module.
    """

    def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
        super(XceptionA, self).__init__()
        # Stem: stride-2 3x3 convolution from 3 to 8 channels.
        stem = [nn.Conv2d(3, 8, 3, 2, 1, bias=False), norm_layer(8), nn.ReLU(True)]
        self.conv1 = nn.Sequential(*stem)

        # Encoder stages with 4, 6 and 4 units respectively.
        self.enc2 = Enc(8, 48, 4, norm_layer=norm_layer)
        self.enc3 = Enc(48, 96, 6, norm_layer=norm_layer)
        self.enc4 = Enc(96, 192, 4, norm_layer=norm_layer)

        self.fca = FCAttention(192, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(192, num_classes)

    def forward(self, x):
        """Run the network and return the output of the linear classifier."""
        feature_stages = (self.conv1, self.enc2, self.enc3, self.enc4, self.fca, self.avgpool)
        for stage in feature_stages:
            x = stage(x)
        # The pooled map is 1x1 per channel; flatten it for the classifier.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
# Constructor
def get_xception(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``Xception65`` model, optionally loading stored weights.

    Args:
        pretrained: If true, load the state dict of the ``'xception'``
            checkpoint into the model.
        root: Passed to ``get_model_file`` as the checkpoint location
            (that helper lives in ``..model_store`` and is not visible here).
        **kwargs: Forwarded to the ``Xception65`` constructor.

    Returns:
        The constructed model.
    """
    model = Xception65(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # a machine without CUDA; load_state_dict copies the values into the
        # model's own parameters either way.
        state_dict = torch.load(get_model_file('xception', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def get_xception_71(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``Xception71`` model, optionally loading stored weights.

    Args:
        pretrained: If true, load the state dict of the ``'xception71'``
            checkpoint into the model.
        root: Passed to ``get_model_file`` as the checkpoint location
            (that helper lives in ``..model_store`` and is not visible here).
        **kwargs: Forwarded to the ``Xception71`` constructor.

    Returns:
        The constructed model.
    """
    model = Xception71(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # a machine without CUDA; load_state_dict copies the values into the
        # model's own parameters either way.
        state_dict = torch.load(get_model_file('xception71', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def get_xception_a(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``XceptionA`` model, optionally loading stored weights.

    Args:
        pretrained: If true, load the state dict of the ``'xception_a'``
            checkpoint into the model.
        root: Passed to ``get_model_file`` as the checkpoint location
            (that helper lives in ``..model_store`` and is not visible here).
        **kwargs: Forwarded to the ``XceptionA`` constructor.

    Returns:
        The constructed model.
    """
    model = XceptionA(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # a machine without CUDA; load_state_dict copies the values into the
        # model's own parameters either way.
        state_dict = torch.load(get_model_file('xception_a', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
if __name__ == '__main__':
    # Smoke check when run as a script: build the XceptionA model with
    # default arguments (no pretrained weights are loaded).
    model = get_xception_a()