"""Export a CRNN recognition checkpoint to ONNX and then to a TensorRT engine."""
import argparse
import os
import sys
from collections import OrderedDict

# Non-stdlib imports are kept in their original relative order.
import crnn_model.vgg_model as vgg
from ocrTrt import toONNX, ONNXtoTrt
import torch

# Device used both for loading the weights and for the ONNX export.
DEVICE = 'cuda:0'
# Number of output classes passed to vgg.Model.
NUM_CLASS = 97
# Key prefix removed from checkpoint parameter names before loading.
_MODULE_PREFIX = 'module.'


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Keys that do not start with the prefix are copied unchanged, so a
    checkpoint saved without the prefix is not corrupted.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(_MODULE_PREFIX):
            key = key[len(_MODULE_PREFIX):]
        cleaned[key] = value
    return cleaned


def _export_paths(weights, width, height):
    """Return ``(onnx_path, engine_path)`` derived from the weights path.

    Only the file extension is replaced, so the results can never be equal to
    *weights* itself (which would make the export overwrite its own input) and
    a ``.pth`` occurring in a directory name is left alone.
    """
    stem, _ext = os.path.splitext(weights)
    suffix = '_%dX%d' % (width, height)
    return stem + suffix + '.onnx', stem + suffix + '.engine'


def crnnModel(opt):
    """Build the recognition model and load its weights.

    Args:
        opt: Options object; ``opt.weights`` is the checkpoint path given to
            ``torch.load`` and ``opt.mHeight`` is passed to ``vgg.Model`` as
            ``input_height``.

    Returns:
        The ``vgg.Model`` instance with the weights loaded, moved to ``DEVICE``.
    """
    # Build the recognition model.
    network_params = {
        'input_channel': 1,
        'output_channel': 256,
        'hidden_size': 256,
        'input_height': opt.mHeight,
    }
    model = vgg.Model(num_class=NUM_CLASS, **network_params)

    # Fix up the parameter names: remove `module.`
    # (presumably left by a parallel wrapper at save time -- TODO confirm).
    state_dict = torch.load(opt.weights, map_location=DEVICE)
    model.load_state_dict(_strip_module_prefix(state_dict))

    return model.to(DEVICE)


def main():
    """Parse the command line, then export the model to ONNX and TensorRT."""
    parser = argparse.ArgumentParser()
    # NOTE(review): the default ends in '.onnx' although the value is fed to
    # torch.load as a checkpoint -- confirm the intended default.
    parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    opt = parser.parse_args()

    pthmodel = crnnModel(opt)

    # Convert to a TRT model.
    onnxFile, trtFile = _export_paths(opt.weights, opt.mWidth, opt.mHeight)

    print('#'*20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 1, opt.mHeight, opt.mWidth), device=DEVICE)
    print('#'*20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)


if __name__ == '__main__':
    main()