|
- from collections import OrderedDict
-
- import torch
- import torch.nn as nn
-
- from .bn import ABN
-
-
class DenseModule(nn.Module):
    """Densely connected stack of bottleneck layers.

    Each layer takes the concatenation (along dim 1) of the module input and
    the outputs of all previous layers, passes it through a 1x1 bottleneck
    convolution followed by a 3x3 convolution, and contributes ``growth`` new
    channels to the running feature stack.

    Parameters
    ----------
    in_channels : int
        Number of channels of the tensor fed to the module.
    growth : int
        Number of channels produced by every layer.
    layers : int
        Number of bottleneck layers in the stack.
    bottleneck_factor : int
        The 1x1 convolution outputs ``growth * bottleneck_factor`` channels.
    norm_act : callable
        Called with a channel count; the result is placed in front of each
        convolution as the ``"bn"`` sub-module. Defaults to ``ABN``
        (presumably a fused normalization + activation layer -- defined in
        ``.bn``, not visible here).
    dilation : int
        Dilation of the 3x3 convolutions. The same value is used as padding,
        so with the default stride the spatial size is preserved.
    """

    def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
        super(DenseModule, self).__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()

        # Width of the intermediate (bottleneck) representation, shared by all layers.
        bottleneck_channels = self.growth * bottleneck_factor

        # Input width of the current layer; grows by `growth` after every layer
        # because each layer sees everything produced before it.
        layer_in = in_channels
        for _ in range(self.layers):
            reduce_block = nn.Sequential(OrderedDict([
                ("bn", norm_act(layer_in)),
                ("conv", nn.Conv2d(layer_in, bottleneck_channels, 1, bias=False)),
            ]))
            self.convs1.append(reduce_block)

            grow_block = nn.Sequential(OrderedDict([
                ("bn", norm_act(bottleneck_channels)),
                ("conv", nn.Conv2d(
                    bottleneck_channels,
                    self.growth,
                    3,
                    padding=dilation,
                    bias=False,
                    dilation=dilation,
                )),
            ]))
            self.convs3.append(grow_block)

            layer_in += self.growth

    @property
    def out_channels(self):
        """Channel count of the tensor returned by :meth:`forward`."""
        added_channels = self.growth * self.layers
        return self.in_channels + added_channels

    def forward(self, x):
        """Run the dense stack and return the input concatenated with every layer output."""
        features = [x]
        for idx in range(self.layers):
            # Every layer consumes all features accumulated so far.
            stacked = torch.cat(features, dim=1)
            features.append(self.convs3[idx](self.convs1[idx](stacked)))

        return torch.cat(features, dim=1)
|