"""Convert a CRNN text-recognition checkpoint to ONNX and then to a TensorRT engine."""
import argparse
import os
import sys
from collections import OrderedDict

import torch

from ocrTrt import toONNX, ONNXtoTrt


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from the start of keys.

    Keys that do not start with ``prefix`` are kept unchanged rather than
    having a fixed number of characters sliced off.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def crnnModel(opt, device='cuda:0'):
    """Build a CRNN recognition model and load its weights from ``opt.weights``.

    Args:
        opt: parsed command-line namespace; ``opt.mode`` and ``opt.weights``
            are read. ``opt.mode == 'en'`` (after stripping whitespace)
            selects ``crnn_model.CRNN``; any other value selects
            ``crnnCh.CRNN``.
        device: torch device the loaded model is moved to.

    Returns:
        The model with weights loaded, moved to ``device``.

    Raises:
        RuntimeError: if the checkpoint parameters do not match the model,
            even after removing a leading ``module.`` from each key.
    """
    mode = opt.mode.strip()
    model_path = opt.weights

    # Build the recognition model. Imports are local so only the selected
    # architecture module is imported.
    if mode == 'en':
        import crnn_model
        model = crnn_model.CRNN(32, 1, 93, 256)
    else:
        import crnnCh as crnn
        model = crnn.CRNN(3, 256, 7935, 32)
    print('####line24:', mode)

    # Load onto the CPU first so a checkpoint saved on a different GPU still
    # loads; the model is moved to ``device`` at the end.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Some checkpoints wrap the weights under a 'state_dict' key.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: weights saved from an nn.DataParallel-wrapped model
        # carry a 'module.' prefix on every key. Retry with it removed,
        # reusing the already-loaded checkpoint instead of re-reading the file.
        model.load_state_dict(_strip_module_prefix(state_dict))

    return model.to(device)


def _output_paths(weights, width, height):
    """Return ``(onnx_path, engine_path)`` derived from the weights path.

    ``model.pth`` with width 640 and height 360 yields ``model_640X360.onnx``
    and ``model_640X360.engine``. Only the final extension is replaced, so a
    weights path without ``.pth`` still produces two distinct output files.
    """
    stem, _ = os.path.splitext(weights)
    suffix = '_%dX%d' % (width, height)
    return stem + suffix + '.onnx', stem + suffix + '.engine'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): this default names an .onnx file, but crnnModel reads it
    # with torch.load, which expects a PyTorch checkpoint — confirm.
    parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWidth')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    parser.add_argument('--mode', type=str, default='en',
                        help="'en' selects crnn_model.CRNN (1-channel input); "
                             "any other value selects crnnCh.CRNN (3-channel input)")
    opt = parser.parse_args()

    pthmodel = crnnModel(opt)

    # Convert the PyTorch model to ONNX, then to a TensorRT engine.
    onnxFile, trtFile = _output_paths(opt.weights, opt.mWidth, opt.mHeight)
    print('#'*20, ' begin to toONNX')
    # Strip the mode the same way crnnModel does so the export input shape
    # matches the architecture that was actually built.
    channels = 1 if opt.mode.strip() == 'en' else 3
    inputShape = (1, channels, opt.mHeight, opt.mWidth)
    # NOTE(review): the model is not switched to eval() here; confirm that
    # toONNX does so before exporting.
    toONNX(pthmodel, onnxFile, inputShape=inputShape, device='cuda:0')
    print('#'*20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)