135 lines
5.0 KiB
Python
135 lines
5.0 KiB
Python
|
|
"""Basic Module for Semantic Segmentation"""
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
__all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual']
|
||
|
|
|
||
|
|
|
||
|
|
class _ConvBNReLU(nn.Module):
|
||
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||
|
|
dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(_ConvBNReLU, self).__init__()
|
||
|
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
|
||
|
|
self.bn = norm_layer(out_channels)
|
||
|
|
self.relu = nn.ReLU6(False) if relu6 else nn.ReLU(False)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
x = self.conv(x)
|
||
|
|
x = self.bn(x)
|
||
|
|
x = self.relu(x)
|
||
|
|
return x
|
||
|
|
|
||
|
|
|
||
|
|
class _ConvBNPReLU(nn.Module):
|
||
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||
|
|
dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(_ConvBNPReLU, self).__init__()
|
||
|
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
|
||
|
|
self.bn = norm_layer(out_channels)
|
||
|
|
self.prelu = nn.PReLU(out_channels)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
x = self.conv(x)
|
||
|
|
x = self.bn(x)
|
||
|
|
x = self.prelu(x)
|
||
|
|
return x
|
||
|
|
|
||
|
|
|
||
|
|
class _ConvBN(nn.Module):
|
||
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||
|
|
dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(_ConvBN, self).__init__()
|
||
|
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
|
||
|
|
self.bn = norm_layer(out_channels)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
x = self.conv(x)
|
||
|
|
x = self.bn(x)
|
||
|
|
return x
|
||
|
|
|
||
|
|
|
||
|
|
class _BNPReLU(nn.Module):
|
||
|
|
def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(_BNPReLU, self).__init__()
|
||
|
|
self.bn = norm_layer(out_channels)
|
||
|
|
self.prelu = nn.PReLU(out_channels)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
x = self.bn(x)
|
||
|
|
x = self.prelu(x)
|
||
|
|
return x
|
||
|
|
|
||
|
|
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
# For PSPNet
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
class _PSPModule(nn.Module):
|
||
|
|
def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs):
|
||
|
|
super(_PSPModule, self).__init__()
|
||
|
|
out_channels = int(in_channels / 4)
|
||
|
|
self.avgpools = nn.ModuleList()
|
||
|
|
self.convs = nn.ModuleList()
|
||
|
|
for size in sizes:
|
||
|
|
self.avgpool.append(nn.AdaptiveAvgPool2d(size))
|
||
|
|
self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs))
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
size = x.size()[2:]
|
||
|
|
feats = [x]
|
||
|
|
for (avgpool, conv) in enumerate(zip(self.avgpools, self.convs)):
|
||
|
|
feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True))
|
||
|
|
return torch.cat(feats, dim=1)
|
||
|
|
|
||
|
|
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
# For MobileNet
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
class _DepthwiseConv(nn.Module):
|
||
|
|
"""conv_dw in MobileNet"""
|
||
|
|
|
||
|
|
def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(_DepthwiseConv, self).__init__()
|
||
|
|
self.conv = nn.Sequential(
|
||
|
|
_ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer),
|
||
|
|
_ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer))
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
return self.conv(x)
|
||
|
|
|
||
|
|
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
# For MobileNetV2
|
||
|
|
# -----------------------------------------------------------------
|
||
|
|
class InvertedResidual(nn.Module):
|
||
|
|
def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs):
|
||
|
|
super(InvertedResidual, self).__init__()
|
||
|
|
assert stride in [1, 2]
|
||
|
|
self.use_res_connect = stride == 1 and in_channels == out_channels
|
||
|
|
|
||
|
|
layers = list()
|
||
|
|
inter_channels = int(round(in_channels * expand_ratio))
|
||
|
|
if expand_ratio != 1:
|
||
|
|
# pw
|
||
|
|
layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer))
|
||
|
|
layers.extend([
|
||
|
|
# dw
|
||
|
|
_ConvBNReLU(inter_channels, inter_channels, 3, stride, 1,
|
||
|
|
groups=inter_channels, relu6=True, norm_layer=norm_layer),
|
||
|
|
# pw-linear
|
||
|
|
nn.Conv2d(inter_channels, out_channels, 1, bias=False),
|
||
|
|
norm_layer(out_channels)])
|
||
|
|
self.conv = nn.Sequential(*layers)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
if self.use_res_connect:
|
||
|
|
return x + self.conv(x)
|
||
|
|
else:
|
||
|
|
return self.conv(x)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
|
||
|
|
x = torch.randn(1, 32, 64, 64)
|
||
|
|
model = InvertedResidual(32, 64, 2, 1)
|
||
|
|
out = model(x)
|