Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

pth2onnx.py 2.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
#import crnn_model.vgg_model as vgg
import argparse
import os
import sys
from collections import OrderedDict

import torch

from ocrTrt import toONNX, ONNXtoTrt
  7. def crnnModel(opt):
  8. input_height=opt.mHeight
  9. input_width=opt.mWidth
  10. mode=opt.mode.strip()
  11. ##生成识别模型
  12. device='cuda:0'
  13. model_path = opt.weights
  14. if mode=='en':
  15. import crnn_model
  16. model = crnn_model.CRNN(32, 1, 93, 256 )
  17. else:
  18. import crnnCh as crnn
  19. model = crnn.CRNN(3, 256, 7935, 32)
  20. print('####line24:',mode)
  21. checkpoint = torch.load(model_path)
  22. if 'state_dict' in checkpoint.keys():
  23. model.load_state_dict(checkpoint['state_dict'])
  24. else:
  25. try:
  26. model.load_state_dict(checkpoint)
  27. except:
  28. ##修正模型参数的名字
  29. state_dict = torch.load(model_path)
  30. # create new OrderedDict that does not contain `module.`
  31. from collections import OrderedDict
  32. new_state_dict = OrderedDict()
  33. for k, v in state_dict.items():
  34. name = k[7:] # remove `module.`
  35. new_state_dict[name] = v
  36. # load params
  37. model.load_state_dict(new_state_dict)
  38. model = model.to(device)
  39. return model
  40. if __name__=='__main__':
  41. parser = argparse.ArgumentParser()
  42. parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
  43. parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
  44. parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
  45. parser.add_argument('--mode', type=str, default='en', help='segmodel mHeight')
  46. opt = parser.parse_args()
  47. pthmodel = crnnModel(opt)
  48. ###转换TRT模型
  49. onnxFile=opt.weights.replace('.pth','_%dX%d.onnx'%(opt.mWidth,opt.mHeight))
  50. trtFile=opt.weights.replace('.pth','_%dX%d.engine'%(opt.mWidth,opt.mHeight))
  51. print('#'*20, ' begin to toONNX')
  52. if opt.mode=='en':inputShape=(1,1,opt.mHeight, opt.mWidth)
  53. else: inputShape=(1,3,opt.mHeight, opt.mWidth)
  54. toONNX(pthmodel,onnxFile,inputShape=inputShape,device='cuda:0')
  55. print('#'*20, ' begin to TRT')
  56. ONNXtoTrt(onnxFile,trtFile,half=False)