69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
"""Joint Pyramid Upsampling"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# Public API of this module: a star-import exports only JPU.
__all__ = ['JPU']
|
|
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution: depthwise conv -> norm -> 1x1 conv.

    No activation is applied between the stages.

    Args:
        inplanes: Number of input channels. The depthwise stage keeps this
            channel count.
        planes: Number of output channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
        norm_layer: Callable ``norm_layer(num_channels) -> nn.Module``,
            applied between the two convolutions.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1,
                 dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # groups == in_channels: each filter sees exactly one input channel.
        self.conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        self.bn = norm_layer(inplanes)
        # The 1x1 convolution mixes channels and maps inplanes -> planes.
        self.pointwise = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x):
        """Run the depthwise conv, the normalization and the pointwise conv in order."""
        return self.pointwise(self.bn(self.conv(x)))
|
|
|
|
|
|
# Copied from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py
|
|
class JPU(nn.Module):
    """Joint Pyramid Upsampling (JPU) fusion module.

    The module works in four steps:

    1. Project the last three input feature maps to ``width`` channels with
       3x3 conv -> norm -> ReLU.
    2. Bilinearly resize the two coarser projections to the spatial size of
       the projection of ``inputs[-3]``.
    3. Concatenate the three projections.
    4. Run parallel dilated separable-conv branches on the concatenation and
       concatenate their outputs along the channel axis.

    Args:
        in_channels: Sequence of input channel counts. Only the last three
            entries are used: ``in_channels[-1]``, ``[-2]`` and ``[-3]``.
        width: Number of channels produced by each projection and by each
            dilation branch.
        norm_layer: Callable ``norm_layer(num_channels) -> nn.Module``. It is
            used for every normalization layer, including the one inside each
            ``SeparableConv2d`` branch.
        dilations: Dilation rates of the parallel branches. They are
            registered as ``dilation1``, ``dilation2``, ... in this order.
            The default ``(1, 2, 4, 8)`` gives four branches, so the fused
            output has ``len(dilations) * width`` channels.
        **kwargs: Accepted and ignored, so callers can forward a shared
            keyword dict.

    Raises:
        ValueError: If ``dilations`` is empty.
    """

    def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d,
                 dilations=(1, 2, 4, 8), **kwargs):
        super(JPU, self).__init__()
        dilations = tuple(dilations)
        if not dilations:
            raise ValueError('JPU needs at least one dilation rate')

        # Attribute names and registration order are kept unchanged so that
        # existing state_dict keys (conv5.*, conv4.*, conv3.*, dilationN.*)
        # still load.
        self.conv5 = self._conv_bn_relu(in_channels[-1], width, norm_layer)
        self.conv4 = self._conv_bn_relu(in_channels[-2], width, norm_layer)
        self.conv3 = self._conv_bn_relu(in_channels[-3], width, norm_layer)

        names = []
        for index, rate in enumerate(dilations, start=1):
            name = 'dilation%d' % index
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            # norm_layer is forwarded so the separable conv's inner norm
            # follows the caller's choice instead of always using BatchNorm2d.
            branch = nn.Sequential(
                SeparableConv2d(3 * width, width, 3, padding=rate, dilation=rate,
                                bias=False, norm_layer=norm_layer),
                norm_layer(width),
                nn.ReLU(True))
            setattr(self, name, branch)
            names.append(name)
        # Plain tuple (not a submodule container), so state_dict is unaffected.
        self._dilation_names = tuple(names)

    @staticmethod
    def _conv_bn_relu(in_planes, out_planes, norm_layer):
        """Build a 3x3 conv (no bias) -> norm -> in-place ReLU projection."""
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, 3, padding=1, bias=False),
            norm_layer(out_planes),
            nn.ReLU(True))

    def forward(self, *inputs):
        """Fuse the last three feature maps.

        Args:
            *inputs: 4-D feature maps (N, C, H, W). At least three are
                required. The output is resized to the spatial size of the
                projection of ``inputs[-3]``. Presumably that is the
                highest-resolution map — TODO confirm against the backbone.

        Returns:
            Tuple ``(inputs[0], inputs[1], inputs[2], feat)``. ``feat`` is the
            concatenation of all dilation-branch outputs
            (``len(dilations) * width`` channels).
        """
        feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])]
        size = feats[-1].size()[2:]
        feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True)
        feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True)
        feat = torch.cat(feats, dim=1)
        feat = torch.cat([getattr(self, name)(feat) for name in self._dilation_names], dim=1)

        return inputs[0], inputs[1], inputs[2], feat
|