You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
6.2KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Run YOLOv5 benchmarks on all supported export formats
  4. Format | `export.py --include` | Model
  5. --- | --- | ---
  6. PyTorch | - | yolov5s.pt
  7. TorchScript | `torchscript` | yolov5s.torchscript
  8. ONNX | `onnx` | yolov5s.onnx
  9. OpenVINO | `openvino` | yolov5s_openvino_model/
  10. TensorRT | `engine` | yolov5s.engine
  11. CoreML | `coreml` | yolov5s.mlmodel
  12. TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
  13. TensorFlow GraphDef | `pb` | yolov5s.pb
  14. TensorFlow Lite | `tflite` | yolov5s.tflite
  15. TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
  16. TensorFlow.js | `tfjs` | yolov5s_web_model/
  17. Requirements:
  18. $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
  19. $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
  20. $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT
  21. Usage:
  22. $ python utils/benchmarks.py --weights yolov5s.pt --img 640
  23. """
  24. import argparse
  25. import sys
  26. import time
  27. from pathlib import Path
  28. import pandas as pd
  29. FILE = Path(__file__).resolve()
  30. ROOT = FILE.parents[1] # YOLOv5 root directory
  31. if str(ROOT) not in sys.path:
  32. sys.path.append(str(ROOT)) # add ROOT to PATH
  33. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  34. import export
  35. import val
  36. from utils import notebook_init
  37. from utils.general import LOGGER, check_yaml, file_size, print_args
  38. from utils.torch_utils import select_device
  39. def run(
  40. weights=ROOT / 'yolov5s.pt', # weights path
  41. imgsz=640, # inference size (pixels)
  42. batch_size=1, # batch size
  43. data=ROOT / 'data/coco128.yaml', # dataset.yaml path
  44. device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  45. half=False, # use FP16 half-precision inference
  46. test=False, # test exports only
  47. pt_only=False, # test PyTorch only
  48. ):
  49. y, t = [], time.time()
  50. device = select_device(device)
  51. for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
  52. try:
  53. assert i != 9, 'Edge TPU not supported'
  54. assert i != 10, 'TF.js not supported'
  55. if device.type != 'cpu':
  56. assert gpu, f'{name} inference not supported on GPU'
  57. # Export
  58. if f == '-':
  59. w = weights # PyTorch format
  60. else:
  61. w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
  62. assert suffix in str(w), 'export failed'
  63. # Validate
  64. result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
  65. metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
  66. speeds = result[2] # times (preprocess, inference, postprocess)
  67. y.append([name, round(file_size(w), 1), round(metrics[3], 4), round(speeds[1], 2)]) # MB, mAP, t_inference
  68. except Exception as e:
  69. LOGGER.warning(f'WARNING: Benchmark failure for {name}: {e}')
  70. y.append([name, None, None, None]) # mAP, t_inference
  71. if pt_only and i == 0:
  72. break # break after PyTorch
  73. # Print results
  74. LOGGER.info('\n')
  75. parse_opt()
  76. notebook_init() # print system info
  77. c = ['Format', 'Size (MB)', 'mAP@0.5:0.95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
  78. py = pd.DataFrame(y, columns=c)
  79. LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
  80. LOGGER.info(str(py if map else py.iloc[:, :2]))
  81. return py
  82. def test(
  83. weights=ROOT / 'yolov5s.pt', # weights path
  84. imgsz=640, # inference size (pixels)
  85. batch_size=1, # batch size
  86. data=ROOT / 'data/coco128.yaml', # dataset.yaml path
  87. device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  88. half=False, # use FP16 half-precision inference
  89. test=False, # test exports only
  90. pt_only=False, # test PyTorch only
  91. ):
  92. y, t = [], time.time()
  93. device = select_device(device)
  94. for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
  95. try:
  96. w = weights if f == '-' else \
  97. export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
  98. assert suffix in str(w), 'export failed'
  99. y.append([name, True])
  100. except Exception:
  101. y.append([name, False]) # mAP, t_inference
  102. # Print results
  103. LOGGER.info('\n')
  104. parse_opt()
  105. notebook_init() # print system info
  106. py = pd.DataFrame(y, columns=['Format', 'Export'])
  107. LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)')
  108. LOGGER.info(str(py))
  109. return py
  110. def parse_opt():
  111. parser = argparse.ArgumentParser()
  112. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  113. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
  114. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  115. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  116. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  117. parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
  118. parser.add_argument('--test', action='store_true', help='test exports only')
  119. parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
  120. opt = parser.parse_args()
  121. opt.data = check_yaml(opt.data) # check YAML
  122. print_args(vars(opt))
  123. return opt
  124. def main(opt):
  125. test(**vars(opt)) if opt.test else run(**vars(opt))
  126. if __name__ == "__main__":
  127. opt = parse_opt()
  128. main(opt)