AIlib2/ocrUtils2/pth2onnx.py

73 lines
2.3 KiB
Python

#import crnn_model.vgg_model as vgg
import sys
from ocrTrt import toONNX,ONNXtoTrt
from collections import OrderedDict
import torch
import argparse
def crnnModel(opt, device='cuda:0'):
    """Build the CRNN recognition model selected by ``opt.mode`` and load its weights.

    Args:
        opt: parsed command-line options. Reads ``opt.mode`` (``'en'`` after
            stripping whitespace selects ``crnn_model.CRNN(32, 1, 93, 256)``;
            anything else selects ``crnnCh.CRNN(3, 256, 7935, 32)``) and
            ``opt.weights`` (path of the checkpoint passed to ``torch.load``).
        device: device the returned model is moved to. Defaults to
            ``'cuda:0'``, the value that used to be hard-coded.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device``.

    Raises:
        RuntimeError: if the checkpoint keys do not match the model, even after
            stripping a leading ``module.`` from every key that has one.
    """
    mode = opt.mode.strip()
    # Build the recognition model.
    if mode == 'en':
        import crnn_model
        model = crnn_model.CRNN(32, 1, 93, 256)
    else:
        import crnnCh as crnn
        model = crnn.CRNN(3, 256, 7935, 32)
    print('####line24:', mode)

    # Load onto the CPU first so a checkpoint saved on a different GPU still
    # loads; the whole model is moved to `device` at the end.
    # NOTE(review): torch.load unpickles the file -- only use trusted weights.
    checkpoint = torch.load(opt.weights, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: retry with the `module.` prefix removed (presumably
        # weights saved from a DataParallel-wrapped model). Keys without the
        # prefix are kept unchanged instead of having 7 characters chopped off.
        prefix = 'module.'
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)
    return model.to(device)
if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWidth')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    parser.add_argument('--mode', type=str, default='en',
                        help="recognition model: 'en' for the English CRNN, anything else for the Chinese CRNN")
    opt = parser.parse_args()
    # Strip exactly as crnnModel does, so the chosen model and the export
    # input shape always agree.
    mode = opt.mode.strip()
    pthmodel = crnnModel(opt)

    # Convert to a TRT model.
    # Outputs are <weights stem>_<W>X<H>.onnx / .engine. Deriving them via
    # splitext (rather than str.replace('.pth', ...)) guarantees they never
    # equal the input path, so the weights file cannot be overwritten.
    stem = os.path.splitext(opt.weights)[0]
    size_tag = '_%dX%d' % (opt.mWidth, opt.mHeight)
    onnxFile = stem + size_tag + '.onnx'
    trtFile = stem + size_tag + '.engine'

    print('#' * 20, ' begin to toONNX')
    # English model takes 1 input channel, the Chinese model 3 (matches the
    # CRNN constructors in crnnModel).
    channels = 1 if mode == 'en' else 3
    inputShape = (1, channels, opt.mHeight, opt.mWidth)
    toONNX(pthmodel, onnxFile, inputShape=inputShape, device='cuda:0')
    print('#' * 20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)