Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

36 Zeilen
1.4KB

  1. import torch.nn as nn
  2. from .modules import ResNet_FeatureExtractor, BidirectionalLSTM
  3. class Model(nn.Module):
  4. def __init__(self, input_channel, output_channel, hidden_size, num_class):
  5. super(Model, self).__init__()
  6. """ FeatureExtraction """
  7. self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
  8. self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
  9. self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
  10. """ Sequence modeling"""
  11. self.SequenceModeling = nn.Sequential(
  12. BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
  13. BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
  14. self.SequenceModeling_output = hidden_size
  15. """ Prediction """
  16. self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
  17. def forward(self, input, text):
  18. """ Feature extraction stage """
  19. visual_feature = self.FeatureExtraction(input)
  20. visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
  21. visual_feature = visual_feature.squeeze(3)
  22. """ Sequence modeling stage """
  23. contextual_feature = self.SequenceModeling(visual_feature)
  24. """ Prediction stage """
  25. prediction = self.Prediction(contextual_feature.contiguous())
  26. return prediction