ソースを参照

YOLOv5 Export Benchmarks for GPU (#6963)

* Add benchmarks.py GPU support

* Updates

* Updates

* Updates

* Updates

* Add --half

* Add TRT requirements

* Cleanup

* Add TF to warmup types

* Update export.py

* Update export.py

* Update benchmarks.py
modifyDataloader
Glenn Jocher GitHub 2年前
コミット
932dc78496
この署名に対応する既知のキーがデータベースに存在しません GPGキーID: 4AEE18F83AFDEB23
3個のファイルの変更31行の追加18行の削除
  1. +12
    -12
      export.py
  2. +4
    -3
      models/common.py
  3. +15
    -3
      utils/benchmarks.py

+ 12
- 12
export.py ファイルの表示

@@ -75,18 +75,18 @@ from utils.torch_utils import select_device

def export_formats():
# YOLOv5 export formats
x = [['PyTorch', '-', '.pt'],
['TorchScript', 'torchscript', '.torchscript'],
['ONNX', 'onnx', '.onnx'],
['OpenVINO', 'openvino', '_openvino_model'],
['TensorRT', 'engine', '.engine'],
['CoreML', 'coreml', '.mlmodel'],
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
['TensorFlow GraphDef', 'pb', '.pb'],
['TensorFlow Lite', 'tflite', '.tflite'],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
['TensorFlow.js', 'tfjs', '_web_model']]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
x = [['PyTorch', '-', '.pt', True],
['TorchScript', 'torchscript', '.torchscript', True],
['ONNX', 'onnx', '.onnx', True],
['OpenVINO', 'openvino', '_openvino_model', False],
['TensorRT', 'engine', '.engine', True],
['CoreML', 'coreml', '.mlmodel', False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
['TensorFlow GraphDef', 'pb', '.pb', True],
['TensorFlow Lite', 'tflite', '.tflite', False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
['TensorFlow.js', 'tfjs', '_web_model', False]]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):

+ 4
- 3
models/common.py ファイルの表示

@@ -464,10 +464,11 @@ class DetectMultiBackend(nn.Module):

def warmup(self, imgsz=(1, 3, 640, 640)):
# Warmup model by running inference once
if self.pt or self.jit or self.onnx or self.engine: # warmup types
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
if self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
self.forward(im) # warmup
for _ in range(2 if self.jit else 1): #
self.forward(im) # warmup

@staticmethod
def model_type(p='path/to/model.pt'):

+ 15
- 3
utils/benchmarks.py ファイルの表示

@@ -19,6 +19,7 @@ TensorFlow.js | `tfjs` | yolov5s_web_model/
Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
$ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT

Usage:
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
@@ -41,20 +42,29 @@ import export
import val
from utils import notebook_init
from utils.general import LOGGER, print_args
from utils.torch_utils import select_device


def run(weights=ROOT / 'yolov5s.pt', # weights path
imgsz=640, # inference size (pixels)
batch_size=1, # batch size
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
half=False, # use FP16 half-precision inference
):
y, t = [], time.time()
formats = export.export_formats()
for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix)
device = select_device(device)
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
try:
w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
if device.type != 'cpu':
assert gpu, f'{name} inference not supported on GPU'
if f == '-':
w = weights # PyTorch format
else:
w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
assert suffix in str(w), 'export failed'
result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
speeds = result[2] # times (preprocess, inference, postprocess)
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
@@ -78,6 +88,8 @@ def parse_opt():
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt

読み込み中…
キャンセル
保存