import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Enc', 'FCAttention', 'Xception65', 'Xception71', 'XceptionA',
           'get_xception', 'get_xception_71', 'get_xception_a']

# output_stride -> (entry_block3_stride, exit_block20_stride,
#                   middle_block_dilation, exit_block_dilations)
# Shared by Xception65 and Xception71. Wherever a stride is set to 1 to keep
# resolution, the later blocks use a dilation > 1 instead.
_OUTPUT_STRIDE_SETTINGS = {
    32: (2, 2, 1, (1, 1)),
    16: (2, 1, 1, (1, 2)),
    8: (1, 1, 2, (2, 4)),
}


def _output_stride_settings(output_stride):
    """Return the stride/dilation tuple for ``output_stride`` (8, 16 or 32).

    Raises:
        NotImplementedError: for any other ``output_stride``.
    """
    try:
        return _OUTPUT_STRIDE_SETTINGS[output_stride]
    except (KeyError, TypeError):
        raise NotImplementedError(
            'output_stride must be 8, 16 or 32, got {!r}'.format(output_stride)) from None


class SeparableConv2d(nn.Module):
    """Depthwise conv -> norm -> 1x1 pointwise conv, with explicit fixed padding.

    The depthwise conv has one ``kernel_size`` filter per input channel and no
    built-in padding; :meth:`fix_padding` pads the input instead, so the output
    spatial size is ``floor((H - 1) / stride) + 1`` regardless of ``dilation``.

    Args:
        in_channels: channels of the input (and of the depthwise conv).
        out_channels: channels produced by the pointwise conv.
        kernel_size: depthwise kernel size.
        stride: depthwise conv stride.
        dilation: depthwise conv dilation.
        bias: whether both convs have a bias term.
        norm_layer: factory called as ``norm_layer(in_channels)``; defaults to
            ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False,
                 norm_layer=None):
        super(SeparableConv2d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.kernel_size = kernel_size
        self.dilation = dilation
        # padding=0: forward() pads explicitly via fix_padding().
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, 0, dilation,
                               groups=in_channels, bias=bias)
        self.bn = norm_layer(in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        x = self.fix_padding(x, self.kernel_size, self.dilation)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x

    def fix_padding(self, x, kernel_size, dilation):
        """Pad the last two dims of ``x`` by (effective kernel size - 1) in total.

        The effective size of a dilated kernel is
        ``kernel_size + (kernel_size - 1) * (dilation - 1)``. When the total
        padding is odd, the extra pixel goes to the bottom/right.
        """
        kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        # F.pad order: (left, right, top, bottom)
        return F.pad(x, (pad_beg, pad_end, pad_beg, pad_end))


class Block(nn.Module):
    """Residual unit of the modified aligned Xception.

    ``self.rep`` holds ``reps`` stride-1 separable convs, plus one extra
    separable conv when ``stride != 1`` (this one carries the stride) or when
    ``is_last`` is set. The residual branch is the input itself, or a strided
    1x1 conv + norm when the channel count or stride changes.

    Args:
        in_channels / out_channels: channels in and out of the unit.
        reps: number of stride-1 separable convs.
        stride: stride of the extra separable conv and of the 1x1 skip conv.
        dilation: dilation of the stride-1 separable convs.
        norm_layer: factory called as ``norm_layer(channels)``; defaults to
            ``nn.BatchNorm2d`` when ``None``.
        start_with_relu: put a ReLU before the first separable conv.
        grow_first: change the channel count in the first separable conv
            (True) or the last of the ``reps`` convs (False).
        is_last: add a trailing stride-1 separable conv when ``stride == 1``.
    """

    def __init__(self, in_channels, out_channels, reps, stride=1, dilation=1, norm_layer=None,
                 start_with_relu=True, grow_first=True, is_last=False):
        super(Block, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None
        self.relu = nn.ReLU(True)
        rep = []
        filters = in_channels
        if grow_first:
            if start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
            filters = out_channels
        for _ in range(reps - 1):
            if grow_first or start_with_relu:
                rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(filters))
        if not grow_first:
            rep.append(self.relu)
            # NOTE(review): unlike every other conv here, this one is not
            # followed by a norm layer. Adding one would change this
            # Sequential's state_dict keys, so it is left as-is.
            rep.append(SeparableConv2d(in_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
        if stride != 1:
            rep.append(self.relu)
            # The strided conv uses the default dilation of 1.
            rep.append(SeparableConv2d(out_channels, out_channels, 3, stride, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        elif is_last:
            rep.append(self.relu)
            rep.append(SeparableConv2d(out_channels, out_channels, 3, 1, dilation, norm_layer=norm_layer))
            rep.append(norm_layer(out_channels))
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        # NOTE(review): self.relu is in-place. Whenever the first module of
        # self.rep is self.relu (e.g. start_with_relu=True), the call below
        # overwrites x, so the skip branch sees relu(x), not the raw input.
        # Kept for output parity with existing weights; confirm before changing.
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        return out + skip


class Xception65(nn.Module):
    """Modified Aligned Xception-65 classifier.

    Args:
        num_classes: output size of the final linear layer.
        output_stride: 8, 16 or 32; approximate ratio of the input size to the
            spatial size of the last feature map before global pooling.
        norm_layer: factory called as ``norm_layer(channels)``.

    Raises:
        NotImplementedError: for an unsupported ``output_stride``.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception65, self).__init__()
        (entry_block3_stride, exit_block20_stride,
         middle_block_dilation, exit_block_dilations) = _output_stride_settings(output_stride)

        # Entry flow (stride 8 before block3)
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)
        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False,
                            grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: 16 identical 728-channel units
        self.midflow = nn.Sequential(*[
            Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, norm_layer=norm_layer,
                  start_with_relu=True, grow_first=True)
            for _ in range(16)])

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride,
                             dilation=exit_block_dilations[0], norm_layer=norm_layer,
                             start_with_relu=True, grow_first=False, is_last=True)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.block1(x))
        x = self.block2(x)
        x = self.block3(x)
        # Middle flow
        x = self.midflow(x)
        # Exit flow
        x = self.relu(self.block20(x))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        # Classifier head
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


class Xception71(nn.Module):
    """Modified Aligned Xception-71 classifier.

    Same as :class:`Xception65` except the entry flow has one more unit
    (``block2`` is a pair of Blocks).

    Args:
        num_classes: output size of the final linear layer.
        output_stride: 8, 16 or 32; approximate ratio of the input size to the
            spatial size of the last feature map before global pooling.
        norm_layer: factory called as ``norm_layer(channels)``.

    Raises:
        NotImplementedError: for an unsupported ``output_stride``.
    """

    def __init__(self, num_classes=1000, output_stride=32, norm_layer=nn.BatchNorm2d):
        super(Xception71, self).__init__()
        (entry_block3_stride, exit_block20_stride,
         middle_block_dilation, exit_block_dilations) = _output_stride_settings(output_stride)

        # Entry flow (stride 8 before block3, as in Xception65)
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = norm_layer(32)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = norm_layer(64)
        self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
        self.block2 = nn.Sequential(
            Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False,
                  grow_first=True),
            # stride=1: conv1, block1 and the unit above already reach stride 8.
            # A second stride-2 unit here would make the total stride twice the
            # requested output_stride. is_last=True keeps the same trailing
            # separable conv + norm that a strided Block has, so parameter names
            # and shapes are unchanged.
            Block(256, 728, reps=2, stride=1, norm_layer=norm_layer, start_with_relu=False,
                  grow_first=True, is_last=True))
        self.block3 = Block(728, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: 16 identical 728-channel units
        self.midflow = nn.Sequential(*[
            Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, norm_layer=norm_layer,
                  start_with_relu=True, grow_first=True)
            for _ in range(16)])

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride,
                             dilation=exit_block_dilations[0], norm_layer=norm_layer,
                             start_with_relu=True, grow_first=False, is_last=True)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn3 = norm_layer(1536)
        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn4 = norm_layer(1536)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, dilation=exit_block_dilations[1],
                                     norm_layer=norm_layer)
        self.bn5 = norm_layer(2048)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.block1(x))
        x = self.block2(x)
        x = self.block3(x)
        # Middle flow
        x = self.midflow(x)
        # Exit flow
        x = self.relu(self.block20(x))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        # Classifier head
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
# -------------------------------------------------
# For DFANet
# -------------------------------------------------
class BlockA(nn.Module):
    """Bottleneck residual unit built from three separable convs (DFANet backbone).

    ``self.rep`` is [ReLU] -> sep-conv(in -> out/4) -> norm -> ReLU ->
    sep-conv(out/4 -> out/4) -> norm -> ReLU -> sep-conv(out/4 -> out, stride) -> norm.
    The residual branch is the input itself, or a strided 1x1 conv + norm when
    the channel count or stride changes.

    Args:
        in_channels / out_channels: channels in and out of the unit.
        stride: stride of the last separable conv and of the 1x1 skip conv.
        dilation: dilation of the first two separable convs (the last one uses 1).
        norm_layer: factory called as ``norm_layer(channels)``; defaults to
            ``nn.BatchNorm2d`` when ``None``.
        start_with_relu: put a ReLU before the first separable conv.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1, norm_layer=None,
                 start_with_relu=True):
        super(BlockA, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
            self.skipbn = norm_layer(out_channels)
        else:
            self.skip = None
        # Not in-place, so self.rep(x) leaves x intact for the skip branch.
        self.relu = nn.ReLU(False)
        inter_channels = out_channels // 4
        rep = [self.relu] if start_with_relu else []
        rep += [
            SeparableConv2d(in_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer),
            norm_layer(inter_channels),
            self.relu,
            SeparableConv2d(inter_channels, inter_channels, 3, 1, dilation, norm_layer=norm_layer),
            norm_layer(inter_channels),
            self.relu,
            # Only this conv is strided (stride may be 1); it uses the default dilation.
            SeparableConv2d(inter_channels, out_channels, 3, stride, norm_layer=norm_layer),
            norm_layer(out_channels),
        ]
        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        out = self.rep(x)
        if self.skip is not None:
            skip = self.skipbn(self.skip(x))
        else:
            skip = x
        return out + skip


class Enc(nn.Module):
    """Encoder stage: one stride-2 BlockA followed by ``blocks - 1`` stride-1 BlockA units.

    At least one (the stride-2) unit is always built, even when ``blocks < 1``.
    """

    def __init__(self, in_channels, out_channels, blocks, norm_layer=nn.BatchNorm2d):
        super(Enc, self).__init__()
        layers = [BlockA(in_channels, out_channels, 2, norm_layer=norm_layer)]
        layers += [BlockA(out_channels, out_channels, 1, norm_layer=norm_layer)
                   for _ in range(blocks - 1)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class FCAttention(nn.Module):
    """Channel attention: global average pool -> linear -> 1x1 conv + norm + ReLU.

    The resulting per-channel weights (shape ``(n, c, 1, 1)``) multiply the input.

    Args:
        in_channels: channels of the input (``c``).
        norm_layer: factory called as ``norm_layer(in_channels)``.
        hidden_channels: output width of the linear layer (default 1000, as
            before; changing it changes the parameter shapes).
    """

    def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, hidden_channels=1000):
        super(FCAttention, self).__init__()
        self.hidden_channels = hidden_channels
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, hidden_channels)
        self.conv = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
            norm_layer(in_channels),
            nn.ReLU(False))

    def forward(self, x):
        # NOTE(review): with nn.BatchNorm2d in training mode, n == 1 fails here,
        # since the norm sees a single value per channel on a 1x1 map.
        n, c, _, _ = x.size()
        att = self.avgpool(x).view(n, c)
        att = self.fc(att).view(n, self.hidden_channels, 1, 1)
        att = self.conv(att)
        return x * att.expand_as(x)


class XceptionA(nn.Module):
    """DFANet backbone classifier.

    Stride-2 stem conv, three encoder stages (each starts with a stride-2 unit),
    FC attention, global average pooling and a linear classifier.

    Args:
        num_classes: output size of the final linear layer.
        norm_layer: factory called as ``norm_layer(channels)``.
    """

    def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
        super(XceptionA, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(3, 8, 3, 2, 1, bias=False),
                                   norm_layer(8),
                                   nn.ReLU(True))
        self.enc2 = Enc(8, 48, 4, norm_layer=norm_layer)
        self.enc3 = Enc(48, 96, 6, norm_layer=norm_layer)
        self.enc4 = Enc(96, 192, 4, norm_layer=norm_layer)
        self.fca = FCAttention(192, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(192, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.fca(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


# Constructor
def _load_pretrained(model, name, root):
    """Load the checkpoint registered as ``name`` (looked up under ``root``) into ``model``."""
    from ..model_store import get_model_file
    # map_location='cpu' lets checkpoints saved from GPU tensors load on a
    # CPU-only machine; load_state_dict copies the values into the model's own
    # tensors, wherever those live.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(get_model_file(name, root=root), map_location='cpu')
    model.load_state_dict(state_dict)
    return model


def get_xception(pretrained=False, root='~/.torch/models', **kwargs):
    """Build :class:`Xception65`; load the ``'xception'`` weights if ``pretrained``."""
    model = Xception65(**kwargs)
    if pretrained:
        _load_pretrained(model, 'xception', root)
    return model


def get_xception_71(pretrained=False, root='~/.torch/models', **kwargs):
    """Build :class:`Xception71`; load the ``'xception71'`` weights if ``pretrained``."""
    model = Xception71(**kwargs)
    if pretrained:
        _load_pretrained(model, 'xception71', root)
    return model


def get_xception_a(pretrained=False, root='~/.torch/models', **kwargs):
    """Build :class:`XceptionA`; load the ``'xception_a'`` weights if ``pretrained``."""
    model = XceptionA(**kwargs)
    if pretrained:
        _load_pretrained(model, 'xception_a', root)
    return model


if __name__ == '__main__':
    model = get_xception_a()