# AIlib2/ocrUtils/pth2onnx.py

# 49 lines
# 1.7 KiB
# Python

import crnn_model.vgg_model as vgg
import sys
from ocrTrt import toONNX,ONNXtoTrt
from collections import OrderedDict
import torch
import argparse
def stripModulePrefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Keys that do not start with *prefix* are kept unchanged, so a checkpoint
    saved with or without the prefix maps onto the same parameter names.
    Key order is preserved.

    Args:
        state_dict: Mapping of parameter name to value.
        prefix: Key prefix to drop (presumably added by a DataParallel-style
            wrapper at save time -- confirm against the training code).

    Returns:
        An ``OrderedDict`` with the cleaned keys.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the prefix is really there; blindly slicing the
        # first characters would corrupt un-prefixed parameter names.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def crnnModel(opt):
    """Build the CRNN recognition model and load its weights.

    Args:
        opt: Options object; ``opt.weights`` is the checkpoint path passed to
            ``torch.load`` and ``opt.mHeight`` is forwarded to the network as
            ``input_height``.

    Returns:
        The ``vgg.Model`` instance, moved to ``cuda:0``, with the checkpoint
        parameters loaded.
    """
    # Build the recognition model.
    device = 'cuda:0'
    network_params = {
        'input_channel': 1,
        'output_channel': 256,
        'hidden_size': 256,
        'input_height': opt.mHeight,
    }
    num_class = 97
    model = vgg.Model(num_class=num_class, **network_params)

    # Fix up the parameter names stored in the checkpoint, then load them.
    state_dict = torch.load(opt.weights, map_location=device)
    model.load_state_dict(stripModulePrefix(state_dict))

    # Loading once is enough: moving the model afterwards carries the
    # already-loaded parameters along with it.
    model = model.to(device)
    return model
if __name__ == '__main__':
    # Local import: only the script entry point needs path handling.
    import os

    parser = argparse.ArgumentParser()
    # The weights file is read with torch.load() in crnnModel, so the default
    # has to be a PyTorch checkpoint rather than an ONNX file.
    parser.add_argument('--weights', type=str, default='english_g2.pth', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    opt = parser.parse_args()
    pthmodel = crnnModel(opt)

    # Convert to ONNX, then to a TRT engine.
    # Swap only the real file extension: a plain str.replace('.pth', ...)
    # would also rewrite '.pth' inside directory names, and would leave the
    # path untouched (overwriting the input file) for any other extension.
    stem, _ = os.path.splitext(opt.weights)
    size_tag = '_%dX%d' % (opt.mWidth, opt.mHeight)
    onnxFile = stem + size_tag + '.onnx'
    trtFile = stem + size_tag + '.engine'

    print('#'*20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 1, opt.mHeight, opt.mWidth), device='cuda:0')
    print('#'*20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)