import torch.nn as nn
import numpy as np
import torch
from .model_parts import CombinationModule
from . import resnet
import decoder


class _CTRBOXBase(nn.Module):
    """Shared ResNet-18 backbone, top-down fusion neck and per-head branches.

    ``CTRBOX_trt`` and ``CTRBOX_pth`` build exactly the same layers through
    this constructor (so their ``state_dict`` keys are identical); they only
    differ in what ``forward`` returns.

    Args:
        heads: mapping ``head name -> number of output channels``. One
            ``nn.Sequential`` branch is registered per entry, as an attribute
            named after the head.
        pretrained: forwarded to ``resnet.resnet18``.
        down_ratio: must be one of 2, 4, 8, 16; selects the input channel
            count of every head branch via ``channels[log2(down_ratio)]``.
        final_kernel: kernel size of the last conv of every head except
            ``'wh'``, which always ends in a 3x3 conv.
        head_conv: channel count of the hidden 3x3 conv inside each head.
        test_flag: when True, a ``decoder.DecDecoder_test(K=100,
            conf_thresh=0.18, num_classes=15)`` is built and stored as
            ``self.decoder``.
    """

    def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv, test_flag=False):
        super().__init__()
        # Channel table and CombinationModule sizes below match resnet18.
        # Deeper backbones need other values, e.g. resnet50/101 use
        # channels = [3, 64, 256, 512, 1024, 2048] with
        # CombinationModule(512, 256) / (1024, 512) / (2048, 1024).
        self.test_flag = test_flag
        channels = [3, 64, 64, 128, 256, 512]
        assert down_ratio in [2, 4, 8, 16]
        self.l1 = int(np.log2(down_ratio))
        self.base_network = resnet.resnet18(pretrained=pretrained)
        self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
        self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
        self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
        self.heads = heads
        if self.test_flag:
            self.decoder = decoder.DecDecoder_test(K=100, conf_thresh=0.18, num_classes=15)

        # NOTE(review): every head consumes the output of dec_c2, so
        # channels[self.l1] presumably has to equal dec_c2's output channel
        # count (looks like 64, i.e. down_ratio 2 or 4 here) — confirm.
        for head, classes in self.heads.items():
            # 'wh' always ends in a 3x3 conv; the other heads use final_kernel.
            last_kernel = 3 if head == 'wh' else final_kernel
            fc = nn.Sequential(
                nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
                # nn.BatchNorm2d(head_conv),  # BN not used in the paper, but would help stable training
                nn.ReLU(inplace=True),
                nn.Conv2d(head_conv, classes, kernel_size=last_kernel, stride=1,
                          padding=last_kernel // 2, bias=True),
            )
            if 'hm' in head:
                # sigmoid(-2.19) ~= 0.1, so heatmap outputs start at a low value.
                fc[-1].bias.data.fill_(-2.19)
            else:
                self.fill_fc_weights(fc)
            # nn.Module.__setattr__ registers fc as a submodule under this name.
            setattr(self, head, fc)

    def fill_fc_weights(self, m):
        """Zero the bias of every ``nn.Conv2d`` contained in module ``m``.

        Walks ``m.modules()`` (which includes ``m`` itself), so it works both
        on a whole ``nn.Sequential`` head and on a single conv layer. A bare
        ``isinstance(m, nn.Conv2d)`` check would silently skip Sequential heads.
        """
        for layer in m.modules():
            if isinstance(layer, nn.Conv2d) and layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def _head_outputs(self, x):
        """Run backbone, fusion neck and all heads; return ``{head: tensor}``.

        Outputs of heads whose name contains ``'hm'`` or ``'cls'`` are passed
        through ``torch.sigmoid``; the others are returned raw.
        """
        feats = self.base_network(x)
        # Fuse the last four backbone outputs, starting from the last one.
        c4_combine = self.dec_c4(feats[-1], feats[-2])
        c3_combine = self.dec_c3(c4_combine, feats[-3])
        c2_combine = self.dec_c2(c3_combine, feats[-4])

        dec_dict = {}
        for head in self.heads:
            out = getattr(self, head)(c2_combine)
            if 'hm' in head or 'cls' in head:
                out = torch.sigmoid(out)
            dec_dict[head] = out
        return dec_dict

    def _decode_with_heads(self, dec_dict):
        """Decode ``dec_dict`` and return it together with the raw head maps.

        Returns:
            ``(predictions, hm, wh, reg, cls_theta)``; requires those four
            head names to be present in ``self.heads``.
        """
        predictions = self.decoder.ctdet_decode(dec_dict)
        return predictions, dec_dict['hm'], dec_dict['wh'], dec_dict['reg'], dec_dict['cls_theta']


class CTRBOX_trt(_CTRBOXBase):
    """Variant whose ``forward`` always decodes and returns a flat tuple.

    Needs a decoder, i.e. construct it with ``test_flag=True``.
    """

    def forward(self, x):
        """Return ``(predictions, hm, wh, reg, cls_theta)`` for input ``x``.

        Raises:
            RuntimeError: if no decoder was built (``test_flag=False``).
        """
        # hasattr goes through nn.Module.__getattr__, which raises
        # AttributeError when the attribute was never set.
        if not hasattr(self, 'decoder'):
            raise RuntimeError(
                'CTRBOX_trt.forward always decodes predictions; construct the '
                'model with test_flag=True so that the decoder is built.')
        dec_dict = self._head_outputs(x)
        return self._decode_with_heads(dec_dict)


class CTRBOX_pth(_CTRBOXBase):
    """Variant whose ``forward`` output depends on ``test_flag``."""

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            ``(predictions, hm, wh, reg, cls_theta)`` when ``test_flag`` is
            True, otherwise the ``{head: tensor}`` dict of head outputs.
        """
        dec_dict = self._head_outputs(x)
        if self.test_flag:
            return self._decode_with_heads(dec_dict)
        return dec_dict