""" Codes of LinkNet based on https://github.com/snakers4/spacenet-three """ import torch import torch.nn as nn from torch.autograd import Variable from torchvision import models import torch.nn.functional as F from functools import partial nonlinearity = partial(F.relu,inplace=True) class Dblock_more_dilate(nn.Module): def __init__(self,channel): super(Dblock_more_dilate, self).__init__() self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): if m.bias is not None: m.bias.data.zero_() def forward(self, x): dilate1_out = nonlinearity(self.dilate1(x)) dilate2_out = nonlinearity(self.dilate2(dilate1_out)) dilate3_out = nonlinearity(self.dilate3(dilate2_out)) dilate4_out = nonlinearity(self.dilate4(dilate3_out)) dilate5_out = nonlinearity(self.dilate5(dilate4_out)) out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out return out class Dblock(nn.Module): def __init__(self,channel): super(Dblock, self).__init__() self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) #self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): if m.bias is not None: m.bias.data.zero_() def forward(self, x): dilate1_out = nonlinearity(self.dilate1(x)) dilate2_out = nonlinearity(self.dilate2(dilate1_out)) dilate3_out = nonlinearity(self.dilate3(dilate2_out)) dilate4_out = nonlinearity(self.dilate4(dilate3_out)) #dilate5_out = nonlinearity(self.dilate5(dilate4_out)) out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out# + dilate5_out return out class DecoderBlock(nn.Module): def __init__(self, in_channels, n_filters): super(DecoderBlock,self).__init__() self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) self.norm1 = nn.BatchNorm2d(in_channels // 4) self.relu1 = nonlinearity self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) self.norm2 = nn.BatchNorm2d(in_channels // 4) self.relu2 = nonlinearity self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) self.norm3 = nn.BatchNorm2d(n_filters) self.relu3 = nonlinearity def forward(self, x): x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) x = self.deconv2(x) x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) x = self.norm3(x) x = self.relu3(x) return x class DinkNet34_less_pool(nn.Module): def __init__(self, num_classes=1): super(DinkNet34_more_dilate, self).__init__() filters = [64, 128, 256, 512] resnet = models.resnet34(pretrained=True) self.firstconv = resnet.conv1 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool self.encoder1 = resnet.layer1 self.encoder2 = resnet.layer2 self.encoder3 = resnet.layer3 self.dblock = Dblock_more_dilate(256) self.decoder3 = DecoderBlock(filters[2], filters[1]) self.decoder2 = DecoderBlock(filters[1], filters[0]) self.decoder1 = DecoderBlock(filters[0], filters[0]) self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) self.finalrelu1 = nonlinearity self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) self.finalrelu2 = nonlinearity self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) def forward(self, x): # Encoder x = self.firstconv(x) x = self.firstbn(x) x = self.firstrelu(x) x = self.firstmaxpool(x) e1 = self.encoder1(x) e2 = self.encoder2(e1) e3 = self.encoder3(e2) #Center e3 = self.dblock(e3) # Decoder d3 = self.decoder3(e3) + e2 d2 = self.decoder2(d3) + e1 d1 = self.decoder1(d2) # Final Classification out = self.finaldeconv1(d1) out = self.finalrelu1(out) out = self.finalconv2(out) out = self.finalrelu2(out) out = self.finalconv3(out) #return F.sigmoid(out) return out class DinkNet34(nn.Module): def __init__(self, num_classes=1, num_channels=3): super(DinkNet34, self).__init__() filters = [64, 128, 256, 512] resnet = models.resnet34(pretrained=True) self.firstconv = resnet.conv1 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool self.encoder1 = resnet.layer1 self.encoder2 = resnet.layer2 self.encoder3 = resnet.layer3 self.encoder4 = resnet.layer4 self.dblock = Dblock(512) self.decoder4 = DecoderBlock(filters[3], filters[2]) self.decoder3 = DecoderBlock(filters[2], filters[1]) self.decoder2 = DecoderBlock(filters[1], filters[0]) self.decoder1 = DecoderBlock(filters[0], filters[0]) self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) self.finalrelu1 = nonlinearity self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) self.finalrelu2 = nonlinearity self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) def forward(self, x): # Encoder x = self.firstconv(x) x = self.firstbn(x) x = self.firstrelu(x) x = self.firstmaxpool(x) e1 = self.encoder1(x) e2 = self.encoder2(e1) e3 = self.encoder3(e2) e4 = self.encoder4(e3) # Center e4 = self.dblock(e4) # Decoder d4 = self.decoder4(e4) + e3 d3 = self.decoder3(d4) + e2 d2 = self.decoder2(d3) + e1 d1 = self.decoder1(d2) out = self.finaldeconv1(d1) out = self.finalrelu1(out) out = self.finalconv2(out) out = self.finalrelu2(out) out = self.finalconv3(out) #return F.sigmoid(out) return out class DinkNet50(nn.Module): def __init__(self, num_classes=1): super(DinkNet50, self).__init__() filters = [256, 512, 1024, 2048] resnet = models.resnet50(pretrained=True) self.firstconv = resnet.conv1 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool self.encoder1 = resnet.layer1 self.encoder2 = resnet.layer2 self.encoder3 = resnet.layer3 self.encoder4 = resnet.layer4 self.dblock = Dblock_more_dilate(2048) self.decoder4 = DecoderBlock(filters[3], filters[2]) self.decoder3 = DecoderBlock(filters[2], filters[1]) self.decoder2 = DecoderBlock(filters[1], filters[0]) self.decoder1 = DecoderBlock(filters[0], filters[0]) self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) self.finalrelu1 = nonlinearity self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) self.finalrelu2 = nonlinearity self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) def forward(self, x): # Encoder x = self.firstconv(x) x = self.firstbn(x) x = self.firstrelu(x) x = self.firstmaxpool(x) e1 = self.encoder1(x) e2 = self.encoder2(e1) e3 = self.encoder3(e2) e4 = self.encoder4(e3) # Center e4 = self.dblock(e4) # Decoder d4 = self.decoder4(e4) + e3 d3 = self.decoder3(d4) + e2 d2 = self.decoder2(d3) + e1 d1 = self.decoder1(d2) out = self.finaldeconv1(d1) out = self.finalrelu1(out) out = self.finalconv2(out) out = self.finalrelu2(out) out = self.finalconv3(out) #return F.sigmoid(out) return out class DinkNet101(nn.Module): def __init__(self, num_classes=1): super(DinkNet101, self).__init__() filters = [256, 512, 1024, 2048] resnet = models.resnet101(pretrained=True) self.firstconv = resnet.conv1 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool self.encoder1 = resnet.layer1 self.encoder2 = resnet.layer2 self.encoder3 = resnet.layer3 self.encoder4 = resnet.layer4 self.dblock = Dblock_more_dilate(2048) self.decoder4 = DecoderBlock(filters[3], filters[2]) self.decoder3 = DecoderBlock(filters[2], filters[1]) self.decoder2 = DecoderBlock(filters[1], filters[0]) self.decoder1 = DecoderBlock(filters[0], filters[0]) self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) self.finalrelu1 = nonlinearity self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) self.finalrelu2 = nonlinearity self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) def forward(self, x): # Encoder x = self.firstconv(x) x = self.firstbn(x) x = self.firstrelu(x) x = self.firstmaxpool(x) e1 = self.encoder1(x) e2 = self.encoder2(e1) e3 = self.encoder3(e2) e4 = self.encoder4(e3) # Center e4 = self.dblock(e4) # Decoder d4 = self.decoder4(e4) + e3 d3 = self.decoder3(d4) + e2 d2 = self.decoder2(d3) + e1 d1 = self.decoder1(d2) out = self.finaldeconv1(d1) out = self.finalrelu1(out) out = self.finalconv2(out) out = self.finalrelu2(out) out = self.finalconv3(out) #return F.sigmoid(out) return out class LinkNet34(nn.Module): def __init__(self, num_classes=1): super(LinkNet34, self).__init__() filters = [64, 128, 256, 512] resnet = models.resnet34(pretrained=True) self.firstconv = resnet.conv1 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool self.encoder1 = resnet.layer1 self.encoder2 = resnet.layer2 self.encoder3 = resnet.layer3 self.encoder4 = resnet.layer4 self.decoder4 = DecoderBlock(filters[3], filters[2]) self.decoder3 = DecoderBlock(filters[2], filters[1]) self.decoder2 = DecoderBlock(filters[1], filters[0]) self.decoder1 = DecoderBlock(filters[0], filters[0]) self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) self.finalrelu1 = nonlinearity self.finalconv2 = nn.Conv2d(32, 32, 3) self.finalrelu2 = nonlinearity self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1) def forward(self, x): # Encoder x = self.firstconv(x) x = self.firstbn(x) x = self.firstrelu(x) x = self.firstmaxpool(x) e1 = self.encoder1(x) e2 = self.encoder2(e1) e3 = self.encoder3(e2) e4 = self.encoder4(e3) # Decoder d4 = self.decoder4(e4) + e3 d3 = self.decoder3(d4) + e2 d2 = self.decoder2(d3) + e1 d1 = self.decoder1(d2) out = self.finaldeconv1(d1) out = self.finalrelu1(out) out = self.finalconv2(out) out = self.finalrelu2(out) out = self.finalconv3(out) #return F.sigmoid(out) return out