Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

38 lines
1.3KB

  1. """Exports a pytorch *.pt model to *.torchscript format
  2. Usage:
  3. $ export PYTHONPATH="$PWD" && python models/torchscript_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
  4. """
  5. import argparse
  6. from models.common import *
  7. from utils import google_utils
  8. if __name__ == '__main__':
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
  11. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
  12. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  13. opt = parser.parse_args()
  14. print(opt)
  15. # Parameters
  16. f = opt.weights.replace('.pt', '.torchscript') # onnx filename
  17. img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
  18. # Load pytorch model
  19. google_utils.attempt_download(opt.weights)
  20. model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
  21. model.eval()
  22. # Don't fuse layers, it won't work with torchscript exports
  23. #model.fuse()
  24. # Export to jit/torchscript
  25. model.model[-1].export = True # set Detect() layer export=True
  26. _ = model(img) # dry run
  27. traced_script_module = torch.jit.trace(model, img)
  28. traced_script_module.save(f)