"""Joint Pyramid Upsampling""" import torch import torch.nn as nn import torch.nn.functional as F __all__ = ['JPU'] class SeparableConv2d(nn.Module): def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, norm_layer=nn.BatchNorm2d): super(SeparableConv2d, self).__init__() self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) self.bn = norm_layer(inplanes) self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.pointwise(x) return x # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py class JPU(nn.Module): def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs): super(JPU, self).__init__() self.conv5 = nn.Sequential( nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), norm_layer(width), nn.ReLU(True)) self.conv4 = nn.Sequential( nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), norm_layer(width), nn.ReLU(True)) self.conv3 = nn.Sequential( nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), norm_layer(width), nn.ReLU(True)) self.dilation1 = nn.Sequential( SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False), norm_layer(width), nn.ReLU(True)) self.dilation2 = nn.Sequential( SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False), norm_layer(width), nn.ReLU(True)) self.dilation3 = nn.Sequential( SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False), norm_layer(width), nn.ReLU(True)) self.dilation4 = nn.Sequential( SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False), norm_layer(width), nn.ReLU(True)) def forward(self, *inputs): feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])] size = feats[-1].size()[2:] feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True) feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True) feat = torch.cat(feats, dim=1) feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1) return inputs[0], inputs[1], inputs[2], feat