You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

45 lines
1.8KB

  1. import torch.nn as nn
  2. from .modules import VGG_FeatureExtractor, BidirectionalLSTM
  3. class Model(nn.Module):
  4. def __init__(self, input_channel, output_channel, hidden_size, num_class,input_height=64):
  5. super(Model, self).__init__()
  6. """ FeatureExtraction """
  7. self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
  8. self.FeatureExtraction_output = output_channel
  9. if input_height==64:
  10. self.AdaptiveAvgPool = nn.AvgPool2d(kernel_size=(1, 3), stride=(1,1))
  11. elif input_height==32:
  12. self.AdaptiveAvgPool = nn.AvgPool2d(kernel_size=(1, 1), stride=(1,1))
  13. else:
  14. self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))
  15. """ Sequence modeling"""
  16. self.SequenceModeling = nn.Sequential(
  17. BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
  18. BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
  19. self.SequenceModeling_output = hidden_size
  20. """ Prediction """
  21. self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
  22. def forward(self, input, text):
  23. """ Feature extraction stage """
  24. #print('####vgg_model.py line27:',input.size(), 'input[0,0,0:2,0:2] :',input[0,0,0:2,0:2])
  25. visual_feature = self.FeatureExtraction(input)
  26. #print('###line26:',visual_feature.size() )
  27. visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
  28. #print('###line29:',visual_feature.size())
  29. visual_feature = visual_feature.squeeze(3)
  30. """ Sequence modeling stage """
  31. contextual_feature = self.SequenceModeling(visual_feature)
  32. """ Prediction stage """
  33. prediction = self.Prediction(contextual_feature.contiguous())
  34. #print('###line39 vgg_model:',prediction.size())
  35. return prediction