import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class BidirectionalLSTM(nn.Module):
    # nIn: input feature size, nHidden: hidden units per direction,
    # nOut: output feature size.
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Bidirectional, so the LSTM output feature size is nHidden * 2.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)    # [T, b, nHidden * 2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)  # flatten time and batch for the linear layer

        output = self.embedding(t_rec)    # [T * b, nOut]
        output = output.view(T, b, -1)    # restore to [T, b, nOut]

        return output
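

# A minimal shape check for BidirectionalLSTM (added for illustration; the
# demo function below is hypothetical and not part of the original file).
def _demo_bidirectional_lstm():
    layer = BidirectionalLSTM(nIn=512, nHidden=256, nOut=37)
    x = torch.randn(26, 4, 512)           # T=26 timesteps, batch of 4
    assert layer(x).shape == (26, 4, 37)  # [T, b, nOut]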


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):  # n_rnn is accepted but unused
        super(CRNN, self).__init__()

        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer settings of the 7-layer convolutional backbone:
        # kernel sizes, paddings, strides, and output channel counts.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Feature map sizes below (c x h x w) assume a 1 x 32 x 128 input.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # Stride (2, 1) with padding (0, 1) halves the height but barely
        # changes the width, preserving a long horizontal sequence for the RNN.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        convRelu(6, True)  # 512x1x33

        self.cnn = cnn
        # Two stacked bidirectional LSTMs map the 512-dim column features to
        # per-timestep class scores.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))
    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, 512, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]: one timestep per image column

        # per-timestep log-probabilities over classes, as expected by CTC loss
        output = F.log_softmax(self.rnn(conv), dim=2)

        return output
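

# Hedged usage sketch (hypothetical demo, not in the original file): the CNN
# collapses the height to 1 and downsamples the width to imgW // 4 + 1, so a
# [b, 1, 32, imgW] batch yields [imgW // 4 + 1, b, nclass] log-probabilities.
def _demo_crnn():
    model = CRNN(imgH=32, nc=1, nclass=37, nh=256)
    images = torch.randn(4, 1, 32, 128)    # batch of 4 grayscale 32x128 crops
    log_probs = model(images)
    assert log_probs.shape == (33, 4, 37)  # 128 // 4 + 1 = 33 timesteps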


def weights_init(m):
    # DCGAN-style initialization: conv weights ~ N(0, 0.02); batch-norm
    # weights ~ N(1, 0.02) with zero biases.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
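

# Usage note (illustrative): nn.Module.apply walks every submodule, so a
# single call such as CRNN(32, 1, 37, 256).apply(weights_init) initializes all
# conv and batch-norm layers; get_crnn below does exactly this when no
# pretrained weights are given.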


def load_model_weights(model, weight):
    checkpoint = torch.load(weight)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        try:
            model.load_state_dict(checkpoint)
        except RuntimeError:
            # Fix the parameter names: checkpoints saved from an
            # nn.DataParallel-wrapped model prefix every key with `module.`,
            # so strip the prefix and retry.
            new_state_dict = OrderedDict()
            for k, v in checkpoint.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            # load params
            model.load_state_dict(new_state_dict)
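

# Hedged round-trip sketch (hypothetical demo, not in the original file):
# saving from an nn.DataParallel wrapper produces `module.`-prefixed keys,
# which the fallback branch above strips.
def _demo_load_model_weights():
    import io
    model = CRNN(imgH=32, nc=1, nclass=37, nh=256)
    wrapped = nn.DataParallel(model)   # state_dict keys become 'module.<name>'
    buf = io.BytesIO()
    torch.save(wrapped.state_dict(), buf)
    buf.seek(0)
    load_model_weights(model, buf)     # falls back to prefix stripping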


def get_crnn(config, weights=None):
    # nc=1: grayscale input; nclass = NUM_CLASSES + 1 reserves one extra
    # index for the CTC blank label.
    model = CRNN(config.MODEL.IMAGE_SIZE.H, 1, config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)

    if weights:
        load_model_weights(model, weights)
    else:
        model.apply(weights_init)

    return model
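

# Hedged usage sketch (hypothetical demo, not in the original file): `config`
# is presumably a nested config object (e.g. a yacs CfgNode); a SimpleNamespace
# stand-in exposing the attributes accessed above behaves the same.
def _demo_get_crnn():
    from types import SimpleNamespace
    cfg = SimpleNamespace(MODEL=SimpleNamespace(
        IMAGE_SIZE=SimpleNamespace(H=32),
        NUM_CLASSES=36,   # e.g. 10 digits + 26 letters; the CTC blank is added inside
        NUM_HIDDEN=256))
    return get_crnn(cfg)  # no weights given, so weights_init is applied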
|