73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
#import crnn_model.vgg_model as vgg
|
|
|
|
import sys
|
|
from ocrTrt import toONNX,ONNXtoTrt
|
|
from collections import OrderedDict
|
|
import torch
|
|
import argparse
|
|
def crnnModel(opt):
    """Build the CRNN text-recognition model selected by ``opt.mode`` and load its weights.

    Args:
        opt: parsed options object. Reads ``opt.mode`` (whitespace-stripped;
            ``'en'`` selects ``crnn_model.CRNN``, any other value selects
            ``crnnCh.CRNN``) and ``opt.weights`` (checkpoint path passed to
            ``torch.load``).

    Returns:
        The model with the checkpoint's weights loaded, moved to ``'cuda:0'``.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint does not
            match the model, even after a leading ``'module.'`` has been
            stripped from its keys.
    """
    mode = opt.mode.strip()
    device = 'cuda:0'
    model_path = opt.weights

    # Build the recognition model.
    if mode == 'en':
        import crnn_model
        # NOTE(review): argument meaning is defined in crnn_model; presumably
        # (imgH=32, nc=1, nclass=93, nh=256) - confirm.
        model = crnn_model.CRNN(32, 1, 93, 256)
    else:
        import crnnCh as crnn
        # NOTE(review): argument meaning is defined in crnnCh; presumably 3 is
        # the input channel count and 7935 the charset size - confirm.
        model = crnn.CRNN(3, 256, 7935, 32)

    print('####line24:', mode)

    # Map tensors to the CPU while loading so a checkpoint saved on a different
    # GPU still loads; the whole model is moved to `device` at the end.
    checkpoint = torch.load(model_path, map_location='cpu')
    # The checkpoint is either a bare state dict or a dict wrapping one under 'state_dict'.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch. Checkpoints saved from a wrapped model (e.g. nn.DataParallel)
        # typically prefix every key with 'module.'; strip it where present and retry.
        prefix = 'module.'
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped)

    return model.to(device)
|
|
|
|
|
|
def exportPaths(weights, width, height):
    """Return the ``(onnx_path, engine_path)`` pair derived from a checkpoint path.

    The checkpoint's extension is replaced by ``_<width>X<height>.onnx`` and
    ``_<width>X<height>.engine`` (``crnn.pth`` -> ``crnn_640X360.onnx``).
    Because the extension is always replaced, the outputs never coincide with
    the input path, whatever its extension.
    """
    import os
    stem, _ext = os.path.splitext(weights)
    suffix = '_%dX%d' % (width, height)
    return stem + suffix + '.onnx', stem + suffix + '.engine'


def exportInputShape(mode, width, height):
    """Return the dummy-input shape ``(1, channels, height, width)`` used for ONNX export.

    ``mode`` is whitespace-stripped, as ``crnnModel`` does. ``'en'`` gives 1
    channel and any other value gives 3, so the export shape always matches
    the model ``crnnModel`` built.
    """
    channels = 1 if mode.strip() == 'en' else 3
    return (1, channels, height, width)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): the default points at an .onnx file, but crnnModel torch.load()s
    # this path as a PyTorch checkpoint - likely meant to be a .pth file; confirm.
    parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWidth')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    parser.add_argument('--mode', type=str, default='en',
                        help="recognition model: 'en' for English, anything else for Chinese")
    opt = parser.parse_args()

    pthmodel = crnnModel(opt)

    # Convert to TensorRT: PyTorch model -> ONNX file -> TensorRT engine.
    onnxFile, trtFile = exportPaths(opt.weights, opt.mWidth, opt.mHeight)

    print('#'*20, ' begin to toONNX')
    inputShape = exportInputShape(opt.mode, opt.mWidth, opt.mHeight)
    toONNX(pthmodel, onnxFile, inputShape=inputShape, device='cuda:0')

    print('#'*20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)
|
|
|