You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

56 lines
2.1KB

  1. """Exports a YOLOv5 *.pt model to *.onnx and *.torchscript formats
  2. Usage:
  3. $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
  4. """
  5. import argparse
  6. import onnx
  7. from models.common import *
  8. from utils import google_utils
  9. if __name__ == '__main__':
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
  12. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
  13. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  14. opt = parser.parse_args()
  15. opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
  16. print(opt)
  17. # Input
  18. img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
  19. # Load PyTorch model
  20. google_utils.attempt_download(opt.weights)
  21. model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
  22. model.eval()
  23. model.model[-1].export = True # set Detect() layer export=True
  24. _ = model(img) # dry run
  25. # Export to torchscript
  26. try:
  27. f = opt.weights.replace('.pt', '.torchscript') # filename
  28. ts = torch.jit.trace(model, img)
  29. ts.save(f)
  30. print('Torchscript export success, saved as %s' % f)
  31. except:
  32. print('Torchscript export failed.')
  33. # Export to ONNX
  34. try:
  35. f = opt.weights.replace('.pt', '.onnx') # filename
  36. model.fuse() # only for ONNX
  37. torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
  38. output_names=['output']) # output_names=['classes', 'boxes']
  39. # Checks
  40. onnx_model = onnx.load(f) # load onnx model
  41. onnx.checker.check_model(onnx_model) # check onnx model
  42. print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable representation of the graph
  43. print('ONNX export success, saved as %s\nView with https://github.com/lutzroeder/netron' % f)
  44. except:
  45. print('ONNX export failed.')