Nevar pievienot vairāk nekā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domuzīmes ('-') un var būt līdz 35 simboliem gara.

49 rindas
1.7KB

  1. import crnn_model.vgg_model as vgg
  2. import sys
  3. from ocrTrt import toONNX,ONNXtoTrt
  4. from collections import OrderedDict
  5. import torch
  6. import argparse
  7. def crnnModel(opt):
  8. input_height=opt.mHeight
  9. input_width=opt.mWidth
  10. ##生成识别模型
  11. device='cuda:0'
  12. model_path = opt.weights
  13. recog_network, network_params = 'generation2', {'input_channel': 1, 'output_channel': 256, 'hidden_size': 256,'input_height':input_height}
  14. num_class= 97
  15. model = vgg.Model(num_class=num_class, **network_params)
  16. ##修正模型参数的名字
  17. state_dict = torch.load(model_path,map_location=device)
  18. new_state_dict = OrderedDict()
  19. for k, v in state_dict.items():
  20. name = k[7:] # remove `module.`
  21. new_state_dict[name] = v
  22. # load params
  23. model.load_state_dict(new_state_dict)
  24. model = model.to(device)
  25. model.load_state_dict(new_state_dict)
  26. return model
  27. if __name__=='__main__':
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
  30. parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
  31. parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
  32. opt = parser.parse_args()
  33. pthmodel = crnnModel(opt)
  34. ###转换TRT模型
  35. onnxFile=opt.weights.replace('.pth','_%dX%d.onnx'%(opt.mWidth,opt.mHeight))
  36. trtFile=opt.weights.replace('.pth','_%dX%d.engine'%(opt.mWidth,opt.mHeight))
  37. print('#'*20, ' begin to toONNX')
  38. toONNX(pthmodel,onnxFile,inputShape=(1,1,opt.mHeight, opt.mWidth),device='cuda:0')
  39. print('#'*20, ' begin to TRT')
  40. ONNXtoTrt(onnxFile,trtFile,half=False)