|
- from collections import OrderedDict
-
- import torch.nn as nn
-
- from .bn import ABN
-
-
class IdentityResidualBlock(nn.Module):
    """Residual block that normalizes its input before the convolutions.

    The input first goes through ``norm_act`` (``self.bn1``), then through a
    stack of convolutions (``self.convs``); a shortcut branch is added to the
    result. The shortcut is the raw input, or a strided ``1 x 1`` projection of
    the normalized input when the shapes of the two branches would not match.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Can either have two or three elements: if two,
            construct a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block
            with `1 x 1`, then `3 x 3`, then `1 x 1` convolutions.
        stride : int
            Stride of the first convolution of the block (the first `3 x 3` in the two-element variant, the
            first `1 x 1` in the bottleneck variant). The projection shortcut, when present, uses the same
            stride.
        dilation : int
            Dilation to apply to the `3 x 3` convolutions. The same value is used as their padding.
        groups : int
            Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
            bottleneck blocks.
        norm_act : callable
            Function to create normalization / activation Module. Called with the number of channels.
        dropout : callable
            Function to create Dropout Module. Called with no arguments; when given, the module is inserted
            right before the last convolution of the block.

        Raises
        ------
        ValueError
            If `channels` does not contain exactly two or three values, or if `groups != 1` is requested for
            a two-element (non-bottleneck) block.
        """
        super(IdentityResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        # The shortcut needs a projection whenever the main branch changes the
        # spatial resolution (stride) or the number of channels.
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        self.bn1 = norm_act(in_channels)
        if not is_bottleneck:
            # Basic variant: 3x3 -> norm/act -> [dropout] -> 3x3
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation)),
            ]
            dropout_index = 2
        else:
            # Bottleneck variant: 1x1 -> norm/act -> 3x3 -> norm/act -> [dropout] -> 1x1
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn3", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
            ]
            dropout_index = 4
        if dropout is not None:
            # Dropout sits right before the last convolution in both variants.
            layers.insert(dropout_index, ("dropout", dropout()))
        self.convs = nn.Sequential(OrderedDict(layers))

        # `proj_conv` is only created when needed; `forward` checks for the
        # attribute's existence to pick the shortcut type.
        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Run the block on `x` and return the sum of the residual and shortcut branches.

        When a projection exists, the shortcut is computed from the normalized
        input; otherwise it is a copy of the raw input.
        """
        if hasattr(self, "proj_conv"):
            bn1 = self.bn1(x)
            shortcut = self.proj_conv(bn1)
        else:
            # The copy is taken *before* bn1 runs -- presumably because
            # `norm_act` may modify `x` in place (TODO confirm against .bn).
            shortcut = x.clone()
            bn1 = self.bn1(x)

        out = self.convs(bn1)
        # In-place accumulation of the shortcut into the residual branch output.
        out.add_(shortcut)

        return out
|