|
- import crnn_model.vgg_model as vgg
- import sys
- from ocrTrt import toONNX,ONNXtoTrt
- from collections import OrderedDict
- import torch
- import argparse
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` (or
    ``DistributedDataParallel``) prefix every parameter name with ``module.``.
    Keys without that prefix are kept unchanged rather than blindly truncated.
    """
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def crnnModel(opt):
    """Build the CRNN recognition network and load trained weights into it.

    Args:
        opt: parsed options object. Reads ``opt.mHeight`` (network input
            height) and ``opt.weights`` (path to a checkpoint that
            ``torch.load`` returns as a state dict). Optional attributes
            ``opt.device`` (default ``'cuda:0'``) and ``opt.num_class``
            (default ``97``) override the previously hard-coded values.

    Returns:
        The ``vgg.Model`` instance with the checkpoint weights loaded, moved
        to the selected device.
    """
    device = getattr(opt, 'device', 'cuda:0')
    num_class = getattr(opt, 'num_class', 97)

    # Build the recognition model ('generation2' network configuration).
    network_params = {
        'input_channel': 1,
        'output_channel': 256,
        'hidden_size': 256,
        'input_height': opt.mHeight,
    }
    model = vgg.Model(num_class=num_class, **network_params)

    # Fix the parameter names: drop the `module.` prefix left by a
    # data-parallel wrapper so they match the bare model's names.
    state_dict = torch.load(opt.weights, map_location=device)
    # Load once (the original loaded the same dict twice), then move to device.
    model.load_state_dict(_strip_module_prefix(state_dict))
    return model.to(device)
-
-
def main():
    """Convert a CRNN ``.pth`` checkpoint to ONNX, then to a TensorRT engine.

    Output files are written next to the checkpoint as
    ``<stem>_<W>X<H>.onnx`` and ``<stem>_<W>X<H>.engine``.
    """
    import os

    parser = argparse.ArgumentParser()
    # The checkpoint is read with torch.load, so the default must be a .pth file.
    parser.add_argument('--weights', type=str, default='english_g2.pth', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWidth')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    opt = parser.parse_args()

    pthmodel = crnnModel(opt)
    device = getattr(opt, 'device', 'cuda:0')

    # Convert to a TensorRT model. The output names come from the checkpoint's
    # extension-stripped path rather than a substring replace of '.pth'. If that
    # substring was missing, the replace left the path unchanged, and the
    # exports then overwrote the input file and each other.
    stem = os.path.splitext(opt.weights)[0]
    size_tag = '_%dX%d' % (opt.mWidth, opt.mHeight)
    onnxFile = stem + size_tag + '.onnx'
    trtFile = stem + size_tag + '.engine'

    print('#' * 20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 1, opt.mHeight, opt.mWidth), device=device)
    print('#' * 20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)


if __name__ == '__main__':
    main()
-
|