# Source path: AIlib2/segutils/core/nn/jpu.py (listing metadata: 69 lines, 2.7 KiB, Python)

"""Joint Pyramid Upsampling"""
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['JPU']
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2D convolution with a normalisation layer in between.

    The input first goes through a depthwise convolution (``groups=inplanes``,
    so the channel count stays at ``inplanes``), then through ``norm_layer``,
    and finally through a 1x1 pointwise convolution that maps ``inplanes``
    channels to ``planes`` channels.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels produced by the pointwise conv.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
        norm_layer: Callable building the normalisation layer from a channel
            count; it is called as ``norm_layer(inplanes)``.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1,
                 dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Depthwise stage: one filter group per input channel.
        self.conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        self.bn = norm_layer(inplanes)
        # Pointwise stage: 1x1 convolution mixing channels.
        self.pointwise = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply depthwise conv, normalisation and pointwise conv in sequence."""
        return self.pointwise(self.bn(self.conv(x)))
# Copied from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py
class JPU(nn.Module):
    """Joint Pyramid Upsampling module.

    The last three input feature maps are each projected to ``width``
    channels by a 3x3 conv + norm + ReLU stack. The two coarser projections
    are bilinearly resized to the spatial size of the projection of
    ``inputs[-3]``, and the three are concatenated along the channel axis
    (``3 * width`` channels). That tensor is fed to four parallel
    separable-convolution branches with dilation rates 1, 2, 4 and 8, whose
    outputs are concatenated again (``4 * width`` channels).

    Args:
        in_channels: Sequence of channel counts; only the last three entries
            are read, matching ``inputs[-3:]`` passed to ``forward``.
        width: Channel count of every intermediate projection/branch.
        norm_layer: Callable building a normalisation layer from a channel
            count. It is used for every normalisation in this module,
            including the one inside each separable convolution.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs):
        super(JPU, self).__init__()
        # Attribute names and creation order are kept stable so existing
        # checkpoints (state_dict keys) continue to load.
        self.conv5 = self._projection(in_channels[-1], width, norm_layer)
        self.conv4 = self._projection(in_channels[-2], width, norm_layer)
        self.conv3 = self._projection(in_channels[-3], width, norm_layer)

        self.dilation1 = self._dilated_branch(width, 1, norm_layer)
        self.dilation2 = self._dilated_branch(width, 2, norm_layer)
        self.dilation3 = self._dilated_branch(width, 4, norm_layer)
        self.dilation4 = self._dilated_branch(width, 8, norm_layer)

    @staticmethod
    def _projection(in_planes, width, norm_layer):
        """Build the 3x3 conv + norm + ReLU stack projecting to ``width`` channels."""
        return nn.Sequential(
            nn.Conv2d(in_planes, width, 3, padding=1, bias=False),
            norm_layer(width),
            nn.ReLU(True))

    @staticmethod
    def _dilated_branch(width, rate, norm_layer):
        """Build one separable-conv branch with dilation (and padding) ``rate``."""
        return nn.Sequential(
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            # norm_layer is forwarded so the inner normalisation matches the
            # rest of the module instead of falling back to nn.BatchNorm2d.
            SeparableConv2d(3 * width, width, 3, padding=rate, dilation=rate,
                            bias=False, norm_layer=norm_layer),
            norm_layer(width),
            nn.ReLU(True))

    def forward(self, *inputs):
        """Fuse the last three feature maps.

        Args:
            *inputs: At least three feature maps; ``inputs[-3:]`` are fused.

        Returns:
            A 4-tuple ``(inputs[0], inputs[1], inputs[2], feat)`` where
            ``feat`` is the fused tensor with ``4 * width`` channels at the
            spatial size of the projection of ``inputs[-3]``.
        """
        top = self.conv5(inputs[-1])
        mid = self.conv4(inputs[-2])
        base = self.conv3(inputs[-3])

        # Resize the two other projections to the spatial size of ``base``.
        size = base.size()[2:]
        mid = F.interpolate(mid, size, mode='bilinear', align_corners=True)
        top = F.interpolate(top, size, mode='bilinear', align_corners=True)

        feat = torch.cat([top, mid, base], dim=1)
        feat = torch.cat(
            [self.dilation1(feat), self.dilation2(feat),
             self.dilation3(feat), self.dilation4(feat)],
            dim=1)
        return inputs[0], inputs[1], inputs[2], feat