90 lines
3.8 KiB
Python
90 lines
3.8 KiB
Python
import torch.nn as nn
|
|
import numpy as np
|
|
import torch
|
|
from .model_parts import CombinationModule
|
|
from . import resnet
|
|
|
|
class CTRBOX(nn.Module):
    """Encoder-decoder network with one convolutional prediction head per entry in ``heads``.

    The backbone is ``resnet.resnet18``. Its feature list is fused top-down by
    three ``CombinationModule`` stages (``dec_c4`` -> ``dec_c3`` -> ``dec_c2``).
    Every head is applied to the final fused map.

    Args:
        heads: mapping of head name -> number of output channels. One
            ``nn.Sequential`` head is registered as an attribute under each name.
        pretrained: forwarded to ``resnet.resnet18``.
        down_ratio: must be one of 2, 4, 8, 16. ``log2(down_ratio)`` indexes
            ``channels`` to pick the head input width.
        final_kernel: kernel size of the last conv in every head except ``'wh'``
            (which always uses 3x3).
        head_conv: number of channels in each head's hidden conv layer.
    """

    def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv):
        super().__init__()

        # Backbone variants (change channels, backbone and decoder together):
        #   resnet34:     same channels and decoder widths as resnet18 below.
        #   resnet50/101: channels = [3, 64, 256, 512, 1024, 2048] with
        #                 dec_c2/dec_c3/dec_c4 = CombinationModule(512, 256) /
        #                 (1024, 512) / (2048, 1024), all batch_norm=True.
        channels = [3, 64, 64, 128, 256, 512]
        assert down_ratio in [2, 4, 8, 16]
        # NOTE(review): forward() always feeds c2_combine to the heads regardless
        # of down_ratio, so the heads only fit when channels[self.l1] equals
        # dec_c2's output width (presumably 64, i.e. down_ratio 2 or 4) -- confirm.
        self.l1 = int(np.log2(down_ratio))
        self.base_network = resnet.resnet18(pretrained=pretrained)
        self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
        self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
        self.dec_c4 = CombinationModule(512, 256, batch_norm=True)

        self.heads = heads
        for head, classes in self.heads.items():
            # 'wh' always uses a 3x3 output conv; every other head uses final_kernel.
            out_kernel = 3 if head == 'wh' else final_kernel
            fc = nn.Sequential(
                # BatchNorm2d(head_conv) after this conv is deliberately omitted:
                # not used in the paper, though it might stabilise training.
                nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_conv, classes, kernel_size=out_kernel, stride=1,
                          padding=out_kernel // 2, bias=True),
            )
            if 'hm' in head:
                # NOTE(review): -2.19 ~= -log((1 - 0.1) / 0.1); presumably the
                # focal-loss prior-probability bias init (pi ~= 0.1) -- confirm.
                fc[-1].bias.data.fill_(-2.19)
            else:
                self.fill_fc_weights(fc)
            setattr(self, head, fc)

    def fill_fc_weights(self, m):
        """Zero the bias of every ``nn.Conv2d`` in ``m``, including ``m`` itself.

        Args:
            m: an ``nn.Module``. Heads are passed in as ``nn.Sequential``, so the
                submodules must be walked. A bare ``isinstance(m, nn.Conv2d)``
                check on the container would silently do nothing.
        """
        for module in m.modules():
            if isinstance(module, nn.Conv2d) and module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run the backbone, the top-down decoder and every head.

        Args:
            x: input batch passed straight to ``self.base_network``.

        Returns:
            dict mapping each head name to its output tensor. Heads whose name
            contains ``'hm'`` or ``'cls'`` are passed through ``torch.sigmoid``.
        """
        # Assumes base_network returns a list of feature maps ordered shallow ->
        # deep (feats[-1] deepest), matching the decoder widths -- confirm in resnet.py.
        feats = self.base_network(x)
        c4_combine = self.dec_c4(feats[-1], feats[-2])
        c3_combine = self.dec_c3(c4_combine, feats[-3])
        c2_combine = self.dec_c2(c3_combine, feats[-4])

        dec_dict = {}
        for head in self.heads:
            out = getattr(self, head)(c2_combine)
            if 'hm' in head or 'cls' in head:
                out = torch.sigmoid(out)
            dec_dict[head] = out
        return dec_dict
|