Ship_Tilt_Detection/models/ctrbox_net.py

90 lines
3.8 KiB
Python

import torch.nn as nn
import numpy as np
import torch
from .model_parts import CombinationModule
from . import resnet
class CTRBOX(nn.Module):
    """Detector network: ResNet backbone, three-stage fusion decoder, per-task conv heads.

    The backbone's last four feature maps are fused top-down by
    ``dec_c4`` -> ``dec_c3`` -> ``dec_c2``; every head is applied to the
    ``dec_c2`` output.

    Args:
        heads: Mapping ``head name -> number of output channels``. A head whose
            name contains ``'hm'`` gets its final conv bias set to -2.19; every
            other head has the biases of its convs zeroed. Heads whose name
            contains ``'hm'`` or ``'cls'`` are passed through a sigmoid in
            :meth:`forward`.
        pretrained: Forwarded unchanged to the backbone constructor.
        down_ratio: One of 2, 4, 8, 16. ``log2(down_ratio)`` indexes the
            backbone channel table to pick the heads' input width.
        final_kernel: Kernel size of the last conv of every head except
            ``'wh'`` (which always uses a 3x3 conv).
        head_conv: Number of channels of the hidden conv in every head.
        backbone: Name of the constructor in the local ``resnet`` module; must
            be a key of ``_BACKBONE_CONFIGS``. Defaults to ``'resnet18'``, the
            configuration this class previously hard-coded.

    Raises:
        ValueError: If ``backbone`` is not a known configuration.
    """

    # Channel layouts per backbone (previously kept as commented-out code):
    #   [0] backbone channel table, indexed by log2(down_ratio) to get the
    #       heads' input width;
    #   [1] positional channel arguments for the CombinationModule of
    #       dec_c2, dec_c3 and dec_c4 respectively.
    _BACKBONE_CONFIGS = {
        'resnet18': ((3, 64, 64, 128, 256, 512),
                     ((128, 64), (256, 128), (512, 256))),
        'resnet34': ((3, 64, 64, 128, 256, 512),
                     ((128, 64), (256, 128), (512, 256))),
        'resnet50': ((3, 64, 256, 512, 1024, 2048),
                     ((512, 256), (1024, 512), (2048, 1024))),
        'resnet101': ((3, 64, 256, 512, 1024, 2048),
                      ((512, 256), (1024, 512), (2048, 1024))),
    }

    def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv,
                 backbone='resnet18'):
        super().__init__()
        try:
            channels, (c2_args, c3_args, c4_args) = self._BACKBONE_CONFIGS[backbone]
        except KeyError:
            raise ValueError(
                f'unknown backbone {backbone!r}; expected one of '
                f'{sorted(self._BACKBONE_CONFIGS)}') from None
        assert down_ratio in [2, 4, 8, 16]
        # NOTE(review): every head consumes dec_c2's output, so only a
        # down_ratio whose channels[l1] matches dec_c2's output width
        # (presumably its second argument) lines up -- confirm before
        # relying on 8 or 16.
        self.l1 = int(np.log2(down_ratio))
        self.base_network = getattr(resnet, backbone)(pretrained=pretrained)
        self.dec_c2 = CombinationModule(*c2_args, batch_norm=True)
        self.dec_c3 = CombinationModule(*c3_args, batch_norm=True)
        self.dec_c4 = CombinationModule(*c4_args, batch_norm=True)
        self.heads = heads
        for head in self.heads:
            classes = self.heads[head]
            # 'wh' keeps a fixed 3x3 output conv; all other heads use final_kernel.
            last_kernel = 3 if head == 'wh' else final_kernel
            fc = self._make_head(channels[self.l1], head_conv, classes, last_kernel)
            if 'hm' in head:
                # -2.19 ~= log(0.1 / 0.9): the sigmoid applied in forward()
                # starts out predicting ~0.1 for heatmap heads.
                fc[-1].bias.data.fill_(-2.19)
            else:
                self.fill_fc_weights(fc)
            setattr(self, head, fc)

    @staticmethod
    def _make_head(in_channels, head_conv, classes, last_kernel):
        """Build one head: 3x3 conv -> ReLU -> ``last_kernel`` conv ('same' padding for odd kernels)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1, bias=True),
            # nn.BatchNorm2d(head_conv),  # BN not used in the paper, but would help stable training
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, classes, kernel_size=last_kernel, stride=1,
                      padding=last_kernel // 2, bias=True),
        )

    def fill_fc_weights(self, m):
        """Zero the bias of every ``nn.Conv2d`` contained in ``m`` (``m`` itself included).

        ``m`` may be a single conv or a container such as ``nn.Sequential``;
        anything that is not an ``nn.Module`` is ignored. Weights are untouched.
        """
        if not isinstance(m, nn.Module):
            return
        # modules() yields m itself plus all descendants, so a bare Conv2d
        # argument behaves exactly as before, while containers (which is what
        # __init__ passes) now actually get their conv biases zeroed.
        for module in m.modules():
            if isinstance(module, nn.Conv2d) and module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run backbone, decoder and all heads.

        Args:
            x: Input batch handed directly to ``self.base_network``.

        Returns:
            dict mapping each head name to its output tensor; ``'hm'``/``'cls'``
            heads are sigmoid-activated, the rest are raw conv outputs.
        """
        # Backbone returns an indexable sequence of feature maps; the last four
        # are fused from the end (presumably deepest first) -- confirm in resnet.
        feats = self.base_network(x)
        c4_combine = self.dec_c4(feats[-1], feats[-2])
        c3_combine = self.dec_c3(c4_combine, feats[-3])
        c2_combine = self.dec_c2(c3_combine, feats[-4])
        dec_dict = {}
        for head in self.heads:
            out = getattr(self, head)(c2_combine)
            if 'hm' in head or 'cls' in head:
                out = torch.sigmoid(out)
            dec_dict[head] = out
        return dec_dict