49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
import crnn_model.vgg_model as vgg
|
|
import sys
|
|
from ocrTrt import toONNX,ONNXtoTrt
|
|
from collections import OrderedDict
|
|
import torch
|
|
import argparse
|
|
def crnnModel(opt):
    """Build the CRNN recognition model and load its checkpoint.

    Args:
        opt: Parsed options. ``opt.weights`` is the checkpoint path handed to
            ``torch.load``; ``opt.mHeight`` is forwarded to the network as
            ``input_height``.

    Returns:
        The ``vgg.Model`` instance with the checkpoint loaded, moved to
        ``cuda:0``.
    """
    device = 'cuda:0'

    # Build the recognition model.
    network_params = {
        'input_channel': 1,
        'output_channel': 256,
        'hidden_size': 256,
        'input_height': opt.mHeight,
    }
    num_class = 97
    model = vgg.Model(num_class=num_class, **network_params)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(opt.weights, map_location=device)

    # Fix up the parameter names: drop a leading `module.` (presumably added
    # by a DataParallel wrapper at save time -- confirm). Keys without the
    # prefix are kept as they are instead of losing their first 7 characters.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    # Load the parameters once, then move the model to the target device.
    model.load_state_dict(new_state_dict)
    model = model.to(device)
    return model
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load the checkpoint, export it to ONNX, then build
    # a TensorRT engine from the ONNX file.
    import os  # local import: only needed to derive the output file names

    parser = argparse.ArgumentParser()
    # NOTE(review): the default ends in `.onnx` although the file is read with
    # torch.load in crnnModel -- confirm the intended default checkpoint name.
    parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    opt = parser.parse_args()

    pthmodel = crnnModel(opt)

    # Convert to a TRT model.
    # Derive the output names from the path without its extension, so they
    # can never be identical to the input path (the previous
    # `.replace('.pth', ...)` left the name unchanged for non-.pth inputs,
    # making the export overwrite the weights file).
    stem = os.path.splitext(opt.weights)[0]
    size_tag = '_%dX%d' % (opt.mWidth, opt.mHeight)
    onnxFile = stem + size_tag + '.onnx'
    trtFile = stem + size_tag + '.engine'

    print('#'*20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 1, opt.mHeight, opt.mWidth), device='cuda:0')
    print('#'*20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)
|
|
|