You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

205 lines
9.1KB

  1. import torch.nn as nn
  2. import numpy as np
  3. import torch
  4. from .model_parts import CombinationModule
  5. from . import resnet
  6. import decoder
  7. class CTRBOX_trt(nn.Module):
  8. def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv,test_flag=False):
  9. super(CTRBOX_trt, self).__init__()
  10. # channels = [3, 64, 256, 512, 1024, 2048]
  11. # assert down_ratio in [2, 4, 8, 16]
  12. # self.l1 = int(np.log2(down_ratio))
  13. # self.base_network = resnet.resnet101(pretrained=pretrained)
  14. # self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  15. # self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  16. # self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  17. #channels = [3, 64, 256, 512, 1024, 2048]
  18. #assert down_ratio in [2, 4, 8, 16]
  19. #self.l1 = int(np.log2(down_ratio))
  20. #self.base_network = resnet.resnet50(pretrained=pretrained)
  21. #self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  22. #self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  23. #self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  24. #channels = [3, 64, 64, 128, 256, 512]
  25. #assert down_ratio in [2, 4, 8, 16]
  26. #self.l1 = int(np.log2(down_ratio))
  27. #self.base_network = resnet.resnet34(pretrained=pretrained)
  28. #self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  29. #self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  30. #self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  31. self.test_flag=test_flag
  32. channels = [3, 64, 64, 128, 256, 512]
  33. assert down_ratio in [2, 4, 8, 16]
  34. self.l1 = int(np.log2(down_ratio))
  35. self.base_network = resnet.resnet18(pretrained=pretrained)
  36. self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  37. self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  38. self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  39. self.heads = heads
  40. if self.test_flag:
  41. self.decoder = decoder.DecDecoder_test(K=100,
  42. conf_thresh=0.18,
  43. num_classes=15)
  44. for head in self.heads:
  45. classes = self.heads[head]
  46. if head == 'wh':
  47. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  48. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  49. nn.ReLU(inplace=True),
  50. nn.Conv2d(head_conv, classes, kernel_size=3, padding=1, bias=True))
  51. else:
  52. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  53. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  54. nn.ReLU(inplace=True),
  55. nn.Conv2d(head_conv, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))
  56. if 'hm' in head:
  57. fc[-1].bias.data.fill_(-2.19)
  58. else:
  59. self.fill_fc_weights(fc)
  60. self.__setattr__(head, fc)
  61. def fill_fc_weights(self, m):
  62. if isinstance(m, nn.Conv2d):
  63. if m.bias is not None:
  64. nn.init.constant_(m.bias, 0)
  65. def forward(self, x):
  66. x = self.base_network(x)
  67. # import matplotlib.pyplot as plt
  68. # import os
  69. # for idx in range(x[1].shape[1]):
  70. # temp = x[1][0,idx,:,:]
  71. # temp = temp.data.cpu().numpy()
  72. # plt.imsave(os.path.join('dilation', '{}.png'.format(idx)), temp)
  73. c4_combine = self.dec_c4(x[-1], x[-2])
  74. c3_combine = self.dec_c3(c4_combine, x[-3])
  75. c2_combine = self.dec_c2(c3_combine, x[-4])
  76. dec_dict = {}
  77. for head in self.heads:
  78. dec_dict[head] = self.__getattr__(head)(c2_combine)
  79. if 'hm' in head or 'cls' in head:
  80. dec_dict[head] = torch.sigmoid(dec_dict[head])
  81. predictions = self.decoder.ctdet_decode(dec_dict)
  82. #'hm': 'wh':'reg': 'cls_theta':
  83. print('###############line102#############')
  84. return predictions, dec_dict['hm'], dec_dict['wh'], dec_dict['reg'], dec_dict['cls_theta']
  85. #if self.test_flag:
  86. # predictions = self.decoder.ctdet_decode(dec_dict)
  87. # return predictions
  88. #else:
  89. # return dec_dict
  90. class CTRBOX_pth(nn.Module):
  91. def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv,test_flag=False):
  92. super(CTRBOX_pth, self).__init__()
  93. # channels = [3, 64, 256, 512, 1024, 2048]
  94. # assert down_ratio in [2, 4, 8, 16]
  95. # self.l1 = int(np.log2(down_ratio))
  96. # self.base_network = resnet.resnet101(pretrained=pretrained)
  97. # self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  98. # self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  99. # self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  100. #channels = [3, 64, 256, 512, 1024, 2048]
  101. #assert down_ratio in [2, 4, 8, 16]
  102. #self.l1 = int(np.log2(down_ratio))
  103. #self.base_network = resnet.resnet50(pretrained=pretrained)
  104. #self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  105. #self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  106. #self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  107. #channels = [3, 64, 64, 128, 256, 512]
  108. #assert down_ratio in [2, 4, 8, 16]
  109. #self.l1 = int(np.log2(down_ratio))
  110. #self.base_network = resnet.resnet34(pretrained=pretrained)
  111. #self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  112. #self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  113. #self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  114. self.test_flag=test_flag
  115. channels = [3, 64, 64, 128, 256, 512]
  116. assert down_ratio in [2, 4, 8, 16]
  117. self.l1 = int(np.log2(down_ratio))
  118. self.base_network = resnet.resnet18(pretrained=pretrained)
  119. self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  120. self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  121. self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  122. self.heads = heads
  123. if self.test_flag:
  124. self.decoder = decoder.DecDecoder_test(K=100,
  125. conf_thresh=0.18,
  126. num_classes=15)
  127. for head in self.heads:
  128. classes = self.heads[head]
  129. if head == 'wh':
  130. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  131. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  132. nn.ReLU(inplace=True),
  133. nn.Conv2d(head_conv, classes, kernel_size=3, padding=1, bias=True))
  134. else:
  135. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  136. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  137. nn.ReLU(inplace=True),
  138. nn.Conv2d(head_conv, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))
  139. if 'hm' in head:
  140. fc[-1].bias.data.fill_(-2.19)
  141. else:
  142. self.fill_fc_weights(fc)
  143. self.__setattr__(head, fc)
  144. def fill_fc_weights(self, m):
  145. if isinstance(m, nn.Conv2d):
  146. if m.bias is not None:
  147. nn.init.constant_(m.bias, 0)
  148. def forward(self, x):
  149. x = self.base_network(x)
  150. # import matplotlib.pyplot as plt
  151. # import os
  152. # for idx in range(x[1].shape[1]):
  153. # temp = x[1][0,idx,:,:]
  154. # temp = temp.data.cpu().numpy()
  155. # plt.imsave(os.path.join('dilation', '{}.png'.format(idx)), temp)
  156. c4_combine = self.dec_c4(x[-1], x[-2])
  157. c3_combine = self.dec_c3(c4_combine, x[-3])
  158. c2_combine = self.dec_c2(c3_combine, x[-4])
  159. dec_dict = {}
  160. for head in self.heads:
  161. dec_dict[head] = self.__getattr__(head)(c2_combine)
  162. if 'hm' in head or 'cls' in head:
  163. dec_dict[head] = torch.sigmoid(dec_dict[head])
  164. if self.test_flag:
  165. predictions = self.decoder.ctdet_decode(dec_dict)
  166. print('##line301:',predictions )
  167. return predictions, dec_dict['hm'], dec_dict['wh'], dec_dict['reg'], dec_dict['cls_theta']
  168. else:
  169. return dec_dict