import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class BidirectionalLSTM(nn.Module):
    # nIn: input feature size; nHidden: LSTM hidden units; nOut: output size
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)  # [T, b, nHidden * 2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output
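
# Illustrative shape check for BidirectionalLSTM (a sketch, not part of the
# original file; the sizes below are arbitrary assumptions): a [T, b, nIn]
# sequence comes out as [T, b, nOut].
#
#     rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=37)
#     x = torch.randn(16, 4, 512)        # T=16 steps, batch of 4
#     assert rnn(x).shape == (16, 4, 37)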
class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH must be a multiple of 16'

        # Per-layer conv hyperparameters: kernel sizes, paddings, strides,
        # and output channel counts for the seven conv layers.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        # Two stacked bidirectional LSTMs map the 512-dim conv feature
        # sequence to per-timestep class scores.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))
    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, 512, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features: per-timestep log-probabilities over the classes
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output
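
# Illustrative end-to-end shapes for CRNN (a sketch; the width 100 is an
# arbitrary assumption): a grayscale [b, 1, 32, 100] batch is reduced by the
# conv stack to 512x1x26 feature maps, i.e. a 26-step sequence, so the model
# returns [26, b, nclass] log-probabilities ready for CTC loss/decoding.
#
#     model = CRNN(imgH=32, nc=1, nclass=37, nh=256)
#     images = torch.randn(4, 1, 32, 100)
#     print(model(images).shape)         # torch.Size([26, 4, 37])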
def weights_init(m):
    # DCGAN-style initialization: conv weights ~ N(0, 0.02),
    # batchnorm weights ~ N(1, 0.02) with zero bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def load_model_weights(model, weight):
    checkpoint = torch.load(weight)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        try:
            model.load_state_dict(checkpoint)
        except RuntimeError:
            # Fix the parameter names: create a new OrderedDict whose keys
            # do not carry the `module.` prefix.
            new_state_dict = OrderedDict()
            for k, v in checkpoint.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            # load params
            model.load_state_dict(new_state_dict)
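
# Hedged note on the fallback above (the key name is illustrative): checkpoints
# saved from an nn.DataParallel-wrapped model prefix every parameter key with
# `module.`, e.g.
#
#     "module.cnn.conv0.weight"[7:]  # -> "cnn.conv0.weight"
#
# which is why stripping the first seven characters recovers the plain keys.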
def get_crnn(config, weights=None):
    model = CRNN(config.MODEL.IMAGE_SIZE.H, 1,
                 config.MODEL.NUM_CLASSES + 1,  # +1 for the CTC blank class
                 config.MODEL.NUM_HIDDEN)
    if weights:
        load_model_weights(model, weights)
    else:
        model.apply(weights_init)
    return model
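

# Hedged usage sketch (not from the original file): get_crnn expects a config
# exposing MODEL.IMAGE_SIZE.H, MODEL.NUM_CLASSES and MODEL.NUM_HIDDEN; the
# SimpleNamespace below and every value in it are illustrative assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        MODEL=SimpleNamespace(
            IMAGE_SIZE=SimpleNamespace(H=32),  # must be a multiple of 16
            NUM_CLASSES=36,                    # alphabet size; +1 adds the CTC blank
            NUM_HIDDEN=256,
        )
    )
    model = get_crnn(config)                   # no weights -> random init
    dummy = torch.randn(4, 1, 32, 100)         # [b, c=1, H, W]
    print(model(dummy).shape)                  # torch.Size([26, 4, 37])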