AIlib2/ocrUtils/crnn_model/vgg_model.py

45 lines
1.8 KiB
Python

import torch.nn as nn
from .modules import VGG_FeatureExtractor, BidirectionalLSTM
class Model(nn.Module):
    """CRNN-style recognizer: VGG features -> two BiLSTMs -> linear classifier.

    Args:
        input_channel: Channel count of the input image, forwarded to
            ``VGG_FeatureExtractor``.
        output_channel: Channel count produced by the feature extractor; also
            the input size of the first ``BidirectionalLSTM``.
        hidden_size: Size passed as both hidden and output size to each
            ``BidirectionalLSTM``, and the input size of the final linear layer.
        num_class: Number of output scores per sequence step.
        input_height: Input image height. 64 and 32 select a fixed-kernel
            ``nn.AvgPool2d``; any other value falls back to
            ``nn.AdaptiveAvgPool2d((None, 1))``.
    """

    def __init__(self, input_channel, output_channel, hidden_size, num_class,
                 input_height=64):
        super().__init__()

        # Feature extraction stage.
        self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel

        # Pooling applied to the permuted feature map in forward(); a kernel of
        # (1, k) averages over the last axis only.
        # NOTE(review): the fixed kernels presumably match the extractor's
        # output height for these input sizes (3 rows for 64 px, 1 row for
        # 32 px) and are likely used instead of AdaptiveAvgPool2d for export
        # compatibility -- confirm against VGG_FeatureExtractor. If the real
        # height differs, stride (1, 1) will NOT collapse it to 1.
        if input_height == 64:
            self.AdaptiveAvgPool = nn.AvgPool2d(kernel_size=(1, 3), stride=(1, 1))
        elif input_height == 32:
            self.AdaptiveAvgPool = nn.AvgPool2d(kernel_size=(1, 1), stride=(1, 1))
        else:
            self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # Sequence modeling stage: two stacked bidirectional LSTM blocks.
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size),
        )
        self.SequenceModeling_output = hidden_size

        # Prediction stage: per-step class scores.
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text=None):
        """Run features -> pooling -> sequence model -> linear prediction.

        Args:
            input: Image batch passed straight to ``self.FeatureExtraction``.
            text: Unused by this model; kept (now optional) so existing call
                sites that pass it keep working.

        Returns:
            Output of ``self.Prediction``, whose last dimension is
            ``num_class``.
        """
        # assumes FeatureExtraction returns (batch, channel, height, width) -- TODO confirm
        visual_feature = self.FeatureExtraction(input)
        # (B, C, H, W) -> (B, W, C, H): width becomes the sequence axis, and the
        # pool over the last two axes reduces H (to 1 when the kernel matches H).
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)  # drop the size-1 height axis

        contextual_feature = self.SequenceModeling(visual_feature)
        return self.Prediction(contextual_feature.contiguous())