"""Basic Module for Semantic Segmentation""" import torch import torch.nn as nn import torch.nn.functional as F __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual'] class _ConvBNReLU(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs): super(_ConvBNReLU, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) self.bn = norm_layer(out_channels) self.relu = nn.ReLU6(False) if relu6 else nn.ReLU(False) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x class _ConvBNPReLU(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): super(_ConvBNPReLU, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) self.bn = norm_layer(out_channels) self.prelu = nn.PReLU(out_channels) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.prelu(x) return x class _ConvBN(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): super(_ConvBN, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) self.bn = norm_layer(out_channels) def forward(self, x): x = self.conv(x) x = self.bn(x) return x class _BNPReLU(nn.Module): def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): super(_BNPReLU, self).__init__() self.bn = norm_layer(out_channels) self.prelu = nn.PReLU(out_channels) def forward(self, x): x = self.bn(x) x = self.prelu(x) return x # ----------------------------------------------------------------- # For PSPNet # 
# -----------------------------------------------------------------
class _PSPModule(nn.Module):
    """Pyramid pooling module (PSPNet).

    Each branch works as follows:
    - average-pool the input to ``size``;
    - reduce it to ``in_channels // 4`` channels with a 1x1 ``_ConvBNReLU``;
    - upsample it bilinearly back to the input's spatial size.
    The branch outputs are concatenated with the input along dim 1.

    Args:
        in_channels: channel count of the input feature map.
        sizes: output sizes for ``nn.AdaptiveAvgPool2d``, one branch per entry.
        **kwargs: forwarded to every ``_ConvBNReLU`` (e.g. ``norm_layer``).

    The output has ``in_channels + len(sizes) * (in_channels // 4)`` channels.

    NOTE: with the default ``nn.BatchNorm2d`` in training mode, the size-1
    branch produces a 1x1 map, so a batch size of 1 cannot be normalized
    (PyTorch raises); use batch > 1 or ``eval()`` in that case.
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs):
        super(_PSPModule, self).__init__()
        out_channels = in_channels // 4
        self.avgpools = nn.ModuleList()
        self.convs = nn.ModuleList()
        for size in sizes:
            # Register on ``avgpools`` (the ModuleList created above) so that
            # ``forward`` can pair each pool with its conv.
            self.avgpools.append(nn.AdaptiveAvgPool2d(size))
            self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs))

    def forward(self, x):
        # x must be 4-D: bilinear F.interpolate only accepts 4-D input, and
        # dims 2: are taken as the spatial size to restore.
        size = x.size()[2:]
        feats = [x]
        # Plain zip: wrapping it in enumerate() would yield (idx, (pool, conv))
        # and the two-name unpacking would bind the index as the pool.
        for avgpool, conv in zip(self.avgpools, self.convs):
            feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True))
        return torch.cat(feats, dim=1)


# -----------------------------------------------------------------
# For MobileNet
# -----------------------------------------------------------------
class _DepthwiseConv(nn.Module):
    """conv_dw in MobileNet"""

    # Depthwise 3x3 _ConvBNReLU (groups=in_channels, padding 1, given stride),
    # then a pointwise 1x1 _ConvBNReLU to out_channels. Extra **kwargs are ignored.
    def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_DepthwiseConv, self).__init__()
        self.conv = nn.Sequential(
            _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer),
            _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer))

    def forward(self, x):
        return self.conv(x)


# -----------------------------------------------------------------
# For MobileNetV2
# -----------------------------------------------------------------
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block.

    Layers, in order:
    - optional 1x1 expansion to ``round(in_channels * expand_ratio)`` channels
      with ReLU6, skipped when ``expand_ratio == 1``;
    - 3x3 depthwise ``_ConvBNReLU`` with ReLU6 and the given stride;
    - linear 1x1 projection plus ``norm_layer``, with no activation.
    The input is added back (identity shortcut) only when ``stride == 1``
    and ``in_channels == out_channels``.

    Args:
        in_channels, out_channels: input / output channel counts.
        stride: 1 or 2 (asserted).
        expand_ratio: channel expansion factor for the hidden layer.
        norm_layer: callable taking the channel count, e.g. ``nn.BatchNorm2d``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2], 'stride must be 1 or 2, got {}'.format(stride)
        self.use_res_connect = stride == 1 and in_channels == out_channels

        layers = []
        inter_channels = int(round(in_channels * expand_ratio))
        if expand_ratio != 1:
            # pw
            layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer))
        layers.extend([
            # dw
            _ConvBNReLU(inter_channels, inter_channels, 3, stride, 1,
                        groups=inter_channels, relu6=True, norm_layer=norm_layer),
            # pw-linear: no activation after the projection
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels)])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = x + out
        return out


if __name__ == '__main__':
    x = torch.randn(1, 32, 64, 64)
    model = InvertedResidual(32, 64, 2, 1)
    out = model(x)
    print(out.size())  # torch.Size([1, 64, 32, 32])
    # eval(): BatchNorm cannot use batch statistics on the 1x1 branch with batch size 1.
    psp = _PSPModule(32).eval()
    print(psp(x).size())  # torch.Size([1, 64, 64, 64])