You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

пре 1 година
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import torch.nn as nn
  2. import numpy as np
  3. import torch
  4. from .model_parts import CombinationModule
  5. from . import resnet
  6. import decoder
  7. class CTRBOX(nn.Module):
  8. def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv):
  9. super(CTRBOX, self).__init__()
  10. # channels = [3, 64, 256, 512, 1024, 2048]
  11. # assert down_ratio in [2, 4, 8, 16]
  12. # self.l1 = int(np.log2(down_ratio))
  13. # self.base_network = resnet.resnet101(pretrained=pretrained)
  14. # self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  15. # self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  16. # self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  17. #channels = [3, 64, 256, 512, 1024, 2048]
  18. #assert down_ratio in [2, 4, 8, 16]
  19. #self.l1 = int(np.log2(down_ratio))
  20. #self.base_network = resnet.resnet50(pretrained=pretrained)
  21. #self.dec_c2 = CombinationModule(512, 256, batch_norm=True)
  22. #self.dec_c3 = CombinationModule(1024, 512, batch_norm=True)
  23. #self.dec_c4 = CombinationModule(2048, 1024, batch_norm=True)
  24. #channels = [3, 64, 64, 128, 256, 512]
  25. #assert down_ratio in [2, 4, 8, 16]
  26. #self.l1 = int(np.log2(down_ratio))
  27. #self.base_network = resnet.resnet34(pretrained=pretrained)
  28. #self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  29. #self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  30. #self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  31. channels = [3, 64, 64, 128, 256, 512]
  32. assert down_ratio in [2, 4, 8, 16]
  33. self.l1 = int(np.log2(down_ratio))
  34. self.base_network = resnet.resnet18(pretrained=pretrained)
  35. self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
  36. self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
  37. self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
  38. print('#####################ctrbox_net.py ##############')
  39. self.heads = heads
  40. for head in self.heads:
  41. classes = self.heads[head]
  42. if head == 'wh':
  43. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  44. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  45. nn.ReLU(inplace=True),
  46. nn.Conv2d(head_conv, classes, kernel_size=3, padding=1, bias=True))
  47. else:
  48. fc = nn.Sequential(nn.Conv2d(channels[self.l1], head_conv, kernel_size=3, padding=1, bias=True),
  49. # nn.BatchNorm2d(head_conv), # BN not used in the paper, but would help stable training
  50. nn.ReLU(inplace=True),
  51. nn.Conv2d(head_conv, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))
  52. if 'hm' in head:
  53. fc[-1].bias.data.fill_(-2.19)
  54. else:
  55. self.fill_fc_weights(fc)
  56. self.__setattr__(head, fc)
  57. def fill_fc_weights(self, m):
  58. if isinstance(m, nn.Conv2d):
  59. if m.bias is not None:
  60. nn.init.constant_(m.bias, 0)
  61. def forward(self, x):
  62. x = self.base_network(x)
  63. # import matplotlib.pyplot as plt
  64. # import os
  65. # for idx in range(x[1].shape[1]):
  66. # temp = x[1][0,idx,:,:]
  67. # temp = temp.data.cpu().numpy()
  68. # plt.imsave(os.path.join('dilation', '{}.png'.format(idx)), temp)
  69. c4_combine = self.dec_c4(x[-1], x[-2])
  70. c3_combine = self.dec_c3(c4_combine, x[-3])
  71. c2_combine = self.dec_c2(c3_combine, x[-4])
  72. dec_dict = {}
  73. for head in self.heads:
  74. dec_dict[head] = self.__getattr__(head)(c2_combine)
  75. if 'hm' in head or 'cls' in head:
  76. dec_dict[head] = torch.sigmoid(dec_dict[head])
  77. return dec_dict