import torch.nn as nn

from .modules import VGG_FeatureExtractor, BidirectionalLSTM


class Model(nn.Module):
    """Recognition network: VGG feature extractor -> two BiLSTMs -> linear head.

    Args:
        input_channel: Number of channels passed to ``VGG_FeatureExtractor``.
        output_channel: Output channels of the feature extractor; also the
            input size of the first ``BidirectionalLSTM``.
        hidden_size: Hidden/output size used by both ``BidirectionalLSTM``
            layers and the input size of the final ``nn.Linear``.
        num_class: Output size of the final ``nn.Linear`` (one score per class).
        input_height: Selects the pooling layer. 64 and 32 get a fixed
            ``nn.AvgPool2d`` (kernel (1, 3) and (1, 1) respectively); any other
            value gets ``nn.AdaptiveAvgPool2d((None, 1))``.
    """

    def __init__(self, input_channel, output_channel, hidden_size, num_class, input_height=64):
        super().__init__()

        # --- Feature extraction -------------------------------------------
        self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel

        # Reduce the last axis of the permuted feature map. Fixed kernels are
        # used for the two known input heights; presumably they match the
        # height that VGG_FeatureExtractor produces for those inputs
        # (NOTE(review): confirm against modules.VGG_FeatureExtractor).
        # The attribute keeps its historical name even when the layer is a
        # plain (non-adaptive) AvgPool2d.
        if input_height == 64:
            pool_kernel = (1, 3)
        elif input_height == 32:
            pool_kernel = (1, 1)
        else:
            pool_kernel = None
        self.AdaptiveAvgPool = (
            nn.AdaptiveAvgPool2d((None, 1))
            if pool_kernel is None
            else nn.AvgPool2d(kernel_size=pool_kernel, stride=(1, 1))
        )

        # --- Sequence modeling --------------------------------------------
        first_lstm = BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size)
        second_lstm = BidirectionalLSTM(hidden_size, hidden_size, hidden_size)
        self.SequenceModeling = nn.Sequential(first_lstm, second_lstm)
        self.SequenceModeling_output = hidden_size

        # --- Prediction ---------------------------------------------------
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text):
        """Run the full pipeline and return per-step class scores.

        ``text`` is accepted for interface compatibility but is not used here.
        """
        feature_map = self.FeatureExtraction(input)

        # The feature map is 4-D (permute/squeeze below fix the rank); assumed
        # to be (batch, channel, height, width) -- TODO confirm. Axis 3 is
        # moved to position 1, the pool reduces the new last axis, and
        # squeeze(3) drops it once it has size 1, yielding a sequence tensor.
        pooled = self.AdaptiveAvgPool(feature_map.permute(0, 3, 1, 2))
        sequence = pooled.squeeze(3)

        context = self.SequenceModeling(sequence)
        return self.Prediction(context.contiguous())