412 lines
14 KiB
Python
412 lines
14 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# Names exported by ``from <module> import *``. XceptionA itself is not listed;
# it is built through get_xception_a().
__all__ = ['Enc', 'FCAttention', 'Xception65', 'Xception71', 'get_xception', 'get_xception_71', 'get_xception_a']
|
|
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution: depthwise conv -> norm -> 1x1 pointwise conv.

    The depthwise convolution is created with zero built-in padding; ``forward``
    pads the input explicitly via ``fix_padding`` so that, at stride 1, the
    spatial size of the output equals that of the input.

    Args:
        in_channels: Number of input channels (also the number of depthwise groups).
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
        norm_layer: Callable ``norm_layer(num_channels)`` returning the module
            applied between the two convolutions. ``None`` (the default)
            selects ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
        super(SeparableConv2d, self).__init__()
        # The default used to be called directly and raised TypeError; fall
        # back to plain batch normalisation instead.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.kernel_size = kernel_size
        self.dilation = dilation

        # padding=0 here: padding is applied manually in forward().
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, 0, dilation, groups=in_channels,
                               bias=bias)
        self.bn = norm_layer(in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Pad ``x``, then apply depthwise conv, normalisation and pointwise conv."""
        x = self.fix_padding(x, self.kernel_size, self.dilation)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)

        return x

    def fix_padding(self, x, kernel_size, dilation):
        """Zero-pad the last two dimensions of ``x`` for a dilated kernel.

        The total padding per dimension is the effective (dilated) kernel size
        minus one, split as evenly as possible with any odd remainder placed
        at the end.
        """
        kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        padded_inputs = F.pad(x, (pad_beg, pad_end, pad_beg, pad_end))
        return padded_inputs
|
|
|
|
|
|
class Block(nn.Module):
    """Residual block made of separable convolutions with an additive skip path.

    The main path (``self.rep``) is a sequence of ReLU / ``SeparableConv2d`` /
    norm layers; the skip path is a 1x1 strided convolution followed by a norm
    layer when the channel count or the stride changes, and the identity
    otherwise.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        reps: Number of separable convolutions in the main stack, not counting
            the optional trailing convolution added for ``stride != 1`` or
            ``is_last``.
        stride: Stride of the skip convolution and of the trailing separable
            convolution (only added when ``stride != 1``).
        dilation: Dilation of the separable convolutions (the trailing strided
            convolution uses the ``SeparableConv2d`` default).
        norm_layer: Callable ``norm_layer(num_channels)``; ``None`` (the
            default) selects ``nn.BatchNorm2d``.
        start_with_relu: Whether the main path begins with a ReLU.
        grow_first: If true the first convolution changes the channel count,
            otherwise the last convolution of the main stack does.
        is_last: When ``stride == 1``, append one extra ReLU / separable
            convolution / norm triple.
    """

    def __init__(self, in_channels, out_channels, reps, stride=1, dilation=1, norm_layer=None,
                 start_with_relu=True, grow_first=True, is_last=False):
        super(Block, self).__init__()
        # The default used to be called directly and raised TypeError; fall
        # back to plain batch normalisation instead.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None
        # A single in-place ReLU instance is reused at every activation point.
        # NOTE(review): when it is the first layer of ``rep`` it overwrites the
        # block input before the skip path reads it in forward() - confirm
        # this is intended before changing it (pretrained weights may rely on it).
        self.relu = nn.ReLU(True)
        rep = list()
        filters = in_channels
        if grow_first:
            if start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
            filters = out_channels
        for _ in range(reps - 1):
            if grow_first or start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(filters))
        if not grow_first:
            rep.append(self.relu)
            # No norm layer is appended after this convolution; the layout is
            # kept as-is so state_dict keys stay unchanged.
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
        if stride != 1:
            rep.append(self.relu)
            rep.append(SeparableConv2d(out_channels, out_channels, 3, stride, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        elif is_last:
            rep.append(self.relu)
            rep.append(SeparableConv2d(out_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        """Return ``rep(x)`` plus the (possibly projected) skip connection."""
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        out = out + skip
        return out
|
|
|
|
|
|
class Xception65(nn.Module):
    """Modified Aligned Xception classifier.

    Layout: two-convolution stem, three entry blocks, sixteen middle blocks,
    an exit block followed by three separable convolutions, global average
    pooling and a linear classifier.

    Args:
        num_classes: Output size of the final linear layer.
        output_stride: One of 32, 16 or 8; selects the strides and dilations
            used in the entry, middle and exit flows. Any other value raises
            ``NotImplementedError``.
        norm_layer: Callable ``norm_layer(num_channels)`` used for every
            normalisation layer.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception65, self).__init__()
        # (block3 stride, block20 stride, middle dilation, exit dilations)
        if output_stride == 32:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 2, 2, 1, (1, 1)
        elif output_stride == 16:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 2, 1, 1, (1, 2)
        elif output_stride == 8:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 1, 1, 2, (2, 4)
        else:
            raise NotImplementedError
        block20_dilation, tail_dilation = exit_dilations

        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)

        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: sixteen identical 728-channel blocks.
        self.midflow = nn.Sequential(*[
            Block(728, 728, reps=3, stride=1, dilation=mid_dilation, norm_layer=norm_layer,
                  start_with_relu=True, grow_first=True)
            for _ in range(4, 20)
        ])

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=block20_stride, dilation=block20_dilation,
                             norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the full network and return the output of the linear classifier."""
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.block1(x))
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        x = self.midflow(x)

        # Exit flow
        x = self.relu(self.block20(x))
        tail = ((self.conv3, self.bn3), (self.conv4, self.bn4), (self.conv5, self.bn5))
        for conv, bn in tail:
            x = self.relu(bn(conv(x)))

        # Classifier head: global pooling, flatten, linear layer.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
|
|
|
|
|
class Xception71(nn.Module):
    """Modified Aligned Xception classifier with a deeper entry flow.

    Same overall layout as a stem / entry / sixteen-block middle / exit
    network, except that ``block2`` is a sequence of two blocks
    (128 -> 256 and 256 -> 728) and ``block3`` maps 728 -> 728.

    Args:
        num_classes: Output size of the final linear layer.
        output_stride: One of 32, 16 or 8; selects the strides and dilations
            used in the entry, middle and exit flows. Any other value raises
            ``NotImplementedError``.
        norm_layer: Callable ``norm_layer(num_channels)`` used for every
            normalisation layer.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception71, self).__init__()
        # (block3 stride, block20 stride, middle dilation, exit dilations)
        if output_stride == 32:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 2, 2, 1, (1, 1)
        elif output_stride == 16:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 2, 1, 1, (1, 2)
        elif output_stride == 8:
            block3_stride, block20_stride, mid_dilation, exit_dilations = 1, 1, 2, (2, 4)
        else:
            raise NotImplementedError
        block20_dilation, tail_dilation = exit_dilations

        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)

        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        # NOTE(review): both blocks inside block2 use stride 2, so this stage
        # downsamples by 4 regardless of output_stride - confirm intended.
        self.block2 = nn.Sequential(
            Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True),
            Block(256, 728, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False, grow_first=True))
        self.block3 = Block(728, 728, reps=2, stride=block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: sixteen identical 728-channel blocks.
        self.midflow = nn.Sequential(*[
            Block(728, 728, reps=3, stride=1, dilation=mid_dilation, norm_layer=norm_layer,
                  start_with_relu=True, grow_first=True)
            for _ in range(4, 20)
        ])

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=block20_stride, dilation=block20_dilation,
                             norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=tail_dilation, norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the full network and return the output of the linear classifier."""
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.block1(x))
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        x = self.midflow(x)

        # Exit flow
        x = self.relu(self.block20(x))
        tail = ((self.conv3, self.bn3), (self.conv4, self.bn4), (self.conv5, self.bn5))
        for conv, bn in tail:
            x = self.relu(bn(conv(x)))

        # Classifier head: global pooling, flatten, linear layer.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
|
|
|
|
|
# -------------------------------------------------
|
|
# For DFANet
|
|
# -------------------------------------------------
|
|
class BlockA(nn.Module):
    """Bottleneck-style residual block of three separable convolutions.

    The main path is ``in_channels -> out_channels // 4 -> out_channels // 4
    -> out_channels``, each convolution preceded by a ReLU (the first one only
    when ``start_with_relu``) and followed by a norm layer. The skip path is a
    1x1 strided convolution plus norm layer when the channel count or stride
    changes, and the identity otherwise.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the skip convolution and of the last separable
            convolution.
        dilation: Dilation of the first two separable convolutions (the last
            one uses the ``SeparableConv2d`` default).
        norm_layer: Callable ``norm_layer(num_channels)``; ``None`` (the
            default) selects ``nn.BatchNorm2d``.
        start_with_relu: Whether the main path begins with a ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1, norm_layer=None, start_with_relu=True):
        super(BlockA, self).__init__()
        # The default used to be called directly and raised TypeError; fall
        # back to plain batch normalisation instead.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None
        # Non in-place ReLU, shared by every activation point of the main path.
        self.relu = nn.ReLU(False)
        rep = list()
        inter_channels = out_channels // 4

        if start_with_relu:
            rep.append(self.relu)
        rep.append(SeparableConv2d(in_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer))
        rep.append(norm_layer(inter_channels))

        rep.append(self.relu)
        rep.append(SeparableConv2d(inter_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer))
        rep.append(norm_layer(inter_channels))

        # The original had separate ``stride != 1`` / ``else`` branches that
        # built the same three layers (the else branch hard-coded stride 1),
        # so a single path passing ``stride`` is equivalent.
        rep.append(self.relu)
        rep.append(SeparableConv2d(inter_channels, out_channels, 3, stride, norm_layer=norm_layer))
        rep.append(norm_layer(out_channels))
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        """Return ``rep(x)`` plus the (possibly projected) skip connection."""
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        out = out + skip
        return out
|
|
|
|
|
|
class Enc(nn.Module):
    """Encoder stage: one stride-2 ``BlockA`` followed by ``blocks - 1`` stride-1 ones.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every block of the stage.
        blocks: Total number of ``BlockA`` modules in the stage.
        norm_layer: Callable ``norm_layer(num_channels)`` passed to each block.
    """

    def __init__(self, in_channels, out_channels, blocks, norm_layer=nn.BatchNorm2d):
        super(Enc, self).__init__()
        # The first block changes channel count and downsamples.
        stages = [BlockA(in_channels, out_channels, 2, norm_layer=norm_layer)]
        # The remaining blocks keep both channel count and resolution.
        stages.extend(
            BlockA(out_channels, out_channels, 1, norm_layer=norm_layer)
            for _ in range(blocks - 1)
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked blocks to ``x``."""
        return self.block(x)
|
|
|
|
|
|
class FCAttention(nn.Module):
    """Channel attention computed through a fully connected layer.

    The input is globally average-pooled, mapped to a 1000-dimensional vector
    by a linear layer, projected back to ``in_channels`` by a 1x1 convolution
    with norm and ReLU, and the result rescales the input channel-wise.

    Args:
        in_channels: Channels of the input tensor.
        norm_layer: Callable ``norm_layer(num_channels)`` used after the 1x1
            convolution.
    """

    def __init__(self, in_channels, norm_layer=nn.BatchNorm2d):
        super(FCAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, 1000)
        projection = [
            nn.Conv2d(1000, in_channels, 1, bias=False),
            norm_layer(in_channels),
            nn.ReLU(False),
        ]
        self.conv = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention weights."""
        batch, channels, _, _ = x.size()
        pooled = self.avgpool(x).view(batch, channels)
        # Reshape the linear output to a 1x1 feature map for the convolution.
        weights = self.fc(pooled).view(batch, 1000, 1, 1)
        weights = self.conv(weights)
        return x * weights.expand_as(x)
|
|
|
|
|
|
class XceptionA(nn.Module):
    """Small Xception-style classifier built from ``Enc`` stages and ``FCAttention``.

    Layout: strided 3x3 convolution stem (3 -> 8 channels), three encoder
    stages (8 -> 48 -> 96 -> 192 channels), FC attention, global average
    pooling and a linear classifier.

    Args:
        num_classes: Output size of the final linear layer.
        norm_layer: Callable ``norm_layer(num_channels)`` used for every
            normalisation layer.
    """

    def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
        super(XceptionA, self).__init__()
        stem = (
            nn.Conv2d(3, 8, 3, 2, 1, bias=False),
            norm_layer(8),
            nn.ReLU(True),
        )
        self.conv1 = nn.Sequential(*stem)

        self.enc2 = Enc(8, 48, 4, norm_layer=norm_layer)
        self.enc3 = Enc(48, 96, 6, norm_layer=norm_layer)
        self.enc4 = Enc(96, 192, 4, norm_layer=norm_layer)

        self.fca = FCAttention(192, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(192, num_classes)

    def forward(self, x):
        """Run the full network and return the output of the linear classifier."""
        x = self.conv1(x)

        for stage in (self.enc2, self.enc3, self.enc4):
            x = stage(x)
        x = self.fca(x)

        # Classifier head: global pooling, flatten, linear layer.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
|
|
|
|
|
# Constructor
|
|
def get_xception(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``Xception65`` model, optionally loading pretrained weights.

    Args:
        pretrained: When true, load the state dict stored in the file returned
            by ``get_model_file('xception', root=root)``.
        root: Forwarded to ``get_model_file`` as ``root`` (presumably the
            directory holding downloaded checkpoints - verify against
            ``model_store``).
        **kwargs: Forwarded to the ``Xception65`` constructor.

    Returns:
        The constructed ``Xception65`` instance.
    """
    model = Xception65(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict copies the values into the model's
        # own parameters.
        state_dict = torch.load(get_model_file('xception', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
def get_xception_71(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``Xception71`` model, optionally loading pretrained weights.

    Args:
        pretrained: When true, load the state dict stored in the file returned
            by ``get_model_file('xception71', root=root)``.
        root: Forwarded to ``get_model_file`` as ``root`` (presumably the
            directory holding downloaded checkpoints - verify against
            ``model_store``).
        **kwargs: Forwarded to the ``Xception71`` constructor.

    Returns:
        The constructed ``Xception71`` instance.
    """
    model = Xception71(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict copies the values into the model's
        # own parameters.
        state_dict = torch.load(get_model_file('xception71', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
def get_xception_a(pretrained=False, root='~/.torch/models', **kwargs):
    """Build an ``XceptionA`` model, optionally loading pretrained weights.

    Args:
        pretrained: When true, load the state dict stored in the file returned
            by ``get_model_file('xception_a', root=root)``.
        root: Forwarded to ``get_model_file`` as ``root`` (presumably the
            directory holding downloaded checkpoints - verify against
            ``model_store``).
        **kwargs: Forwarded to the ``XceptionA`` constructor.

    Returns:
        The constructed ``XceptionA`` instance.
    """
    model = XceptionA(**kwargs)
    if pretrained:
        from ..model_store import get_model_file
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict copies the values into the model's
        # own parameters.
        state_dict = torch.load(get_model_file('xception_a', root=root), map_location='cpu')
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check when executed as a script: build an XceptionA with default
    # arguments (no pretrained weights are loaded).
    model = get_xception_a()
|