|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- #import crnn_model.vgg_model as vgg
-
- import sys
- from ocrTrt import toONNX,ONNXtoTrt
- from collections import OrderedDict
- import torch
- import argparse
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from its keys.

    Keys that do not start with *prefix* are kept unchanged (the previous
    implementation blindly dropped the first 7 characters of every key).
    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model are the
    usual source of such a ``module.`` prefix.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def crnnModel(opt):
    """Build the CRNN text-recognition model selected by ``opt.mode`` and load its weights.

    Args:
        opt: parsed command-line options. Reads ``opt.mode`` (``'en'`` selects
            ``crnn_model.CRNN``, any other value selects ``crnnCh.CRNN``; leading
            and trailing whitespace is ignored), ``opt.weights`` (checkpoint
            path) and, if present, ``opt.device`` (defaults to ``'cuda:0'``).

    Returns:
        The model with the checkpoint weights loaded, moved to the target device.

    Raises:
        RuntimeError: if the checkpoint parameters do not match the model, even
            after removing a leading ``module.`` from the parameter names.
    """
    mode = opt.mode.strip()
    device = getattr(opt, 'device', 'cuda:0')
    model_path = opt.weights

    # Build the recognition model. Each project module is imported inside its
    # branch, so only the one for the selected mode is imported.
    if mode == 'en':
        import crnn_model
        model = crnn_model.CRNN(32, 1, 93, 256)
    else:
        import crnnCh as crnn
        model = crnn.CRNN(3, 256, 7935, 32)

    print('####line24:', mode)

    # Load to CPU: load_state_dict copies the tensors into the (still CPU)
    # model anyway, and this avoids failures when the checkpoint was saved on
    # a GPU that is not available here. The file is read only once.
    checkpoint = torch.load(model_path, map_location='cpu')
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Parameter names did not match; retry with a leading `module.` removed.
        model.load_state_dict(_strip_module_prefix(state_dict))

    return model.to(device)
-
-
def main():
    """Convert a CRNN PyTorch checkpoint to ONNX and then to a TensorRT engine.

    Both outputs are written next to the weights file and named after it with
    its extension removed: ``<stem>_<W>X<H>.onnx`` and ``<stem>_<W>X<H>.engine``.
    """
    import os

    parser = argparse.ArgumentParser()
    # NOTE(review): the default names an .onnx file, but crnnModel passes it to
    # torch.load, which expects a PyTorch checkpoint - confirm the intended default.
    parser.add_argument('--weights', type=str, default='english_g2.onnx',
                        help='path of the PyTorch checkpoint to convert')
    parser.add_argument('--mWidth', type=int, default=640,
                        help='input width of the exported model')
    parser.add_argument('--mHeight', type=int, default=360,
                        help='input height of the exported model')
    parser.add_argument('--mode', type=str, default='en',
                        help="'en' for the English CRNN, any other value for the Chinese CRNN")
    opt = parser.parse_args()

    # Strip the mode exactly like crnnModel does, so the model choice and the
    # input shape below always agree.
    mode = opt.mode.strip()

    pthmodel = crnnModel(opt)

    # Convert to a TensorRT model. Output names are built from the weights
    # path without its extension, so the exported files can never overwrite
    # the input checkpoint (str.replace('.pth', ...) did when the path had no
    # '.pth', and it also rewrote '.pth' occurring elsewhere in the path).
    stem = os.path.splitext(opt.weights)[0]
    size_tag = '_%dX%d' % (opt.mWidth, opt.mHeight)
    onnx_file = stem + size_tag + '.onnx'
    trt_file = stem + size_tag + '.engine'

    print('#' * 20, ' begin to toONNX')
    # The English model is fed 1-channel input, the Chinese model 3-channel input.
    channels = 1 if mode == 'en' else 3
    input_shape = (1, channels, opt.mHeight, opt.mWidth)
    toONNX(pthmodel, onnx_file, inputShape=input_shape, device='cuda:0')

    print('#' * 20, ' begin to TRT')
    ONNXtoTrt(onnx_file, trt_file, half=False)


if __name__ == '__main__':
    main()
-
|