Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

62 Zeilen
2.0KB

  1. #import crnn_model.vgg_model as vgg
  2. import crnn_model
  3. import sys
  4. from ocrTrt import toONNX,ONNXtoTrt
  5. from collections import OrderedDict
  6. import torch
  7. import argparse
  8. def crnnModel(opt):
  9. input_height=opt.mHeight
  10. input_width=opt.mWidth
  11. ##生成识别模型
  12. device='cuda:0'
  13. model_path = opt.weights
  14. recog_network, network_params = 'generation2', {'input_channel': 1, 'output_channel': 256, 'hidden_size': 256,'input_height':input_height}
  15. num_class= 97
  16. model = crnn_model.CRNN(32, 1, 93, 256 )
  17. checkpoint = torch.load(model_path)
  18. if 'state_dict' in checkpoint.keys():
  19. model.load_state_dict(checkpoint['state_dict'])
  20. else:
  21. model.load_state_dict(checkpoint)
  22. model = model.to(device)
  23. #model = vgg.Model(num_class=num_class, **network_params)
  24. ##修正模型参数的名字
  25. #state_dict = torch.load(model_path,map_location=device)
  26. #new_state_dict = OrderedDict()
  27. #for k, v in state_dict.items():
  28. # name = k[7:] # remove `module.`
  29. # new_state_dict[name] = v
  30. # load params
  31. #model.load_state_dict(new_state_dict)
  32. #model = model.to(device)
  33. #model.load_state_dict(new_state_dict)
  34. return model
  35. if __name__=='__main__':
  36. parser = argparse.ArgumentParser()
  37. parser.add_argument('--weights', type=str, default='english_g2.onnx', help='model path(s)')
  38. parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWdith')
  39. parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
  40. opt = parser.parse_args()
  41. pthmodel = crnnModel(opt)
  42. ###转换TRT模型
  43. onnxFile=opt.weights.replace('.pth','_%dX%d.onnx'%(opt.mWidth,opt.mHeight))
  44. trtFile=opt.weights.replace('.pth','_%dX%d.engine'%(opt.mWidth,opt.mHeight))
  45. print('#'*20, ' begin to toONNX')
  46. toONNX(pthmodel,onnxFile,inputShape=(1,1,opt.mHeight, opt.mWidth),device='cuda:0')
  47. print('#'*20, ' begin to TRT')
  48. ONNXtoTrt(onnxFile,trtFile,half=False)