Browse Source

TorchScript single-output fix (#7261)

modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
8bc839ed8e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 9 deletions
  1. +12
    -6
      export.py
  2. +4
    -3
      models/common.py

+ 12
- 6
export.py View File

@@ -73,12 +73,18 @@ from utils.torch_utils import select_device

def export_formats():
# YOLOv5 export formats
x = [['PyTorch', '-', '.pt', True], ['TorchScript', 'torchscript', '.torchscript', True],
['ONNX', 'onnx', '.onnx', True], ['OpenVINO', 'openvino', '_openvino_model', False],
['TensorRT', 'engine', '.engine', True], ['CoreML', 'coreml', '.mlmodel', False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True], ['TensorFlow GraphDef', 'pb', '.pb', True],
['TensorFlow Lite', 'tflite', '.tflite', False], ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
['TensorFlow.js', 'tfjs', '_web_model', False]]
x = [
['PyTorch', '-', '.pt', True],
['TorchScript', 'torchscript', '.torchscript', True],
['ONNX', 'onnx', '.onnx', True],
['OpenVINO', 'openvino', '_openvino_model', False],
['TensorRT', 'engine', '.engine', True],
['CoreML', 'coreml', '.mlmodel', False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
['TensorFlow GraphDef', 'pb', '.pb', True],
['TensorFlow Lite', 'tflite', '.tflite', False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
['TensorFlow.js', 'tfjs', '_web_model', False],]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])



+ 4
- 3
models/common.py View File

@@ -406,9 +406,10 @@ class DetectMultiBackend(nn.Module):
def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.pt or self.jit: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
if self.pt: # PyTorch
y = self.model(im, augment=augment, visualize=visualize)[0]
elif self.jit: # TorchScript
y = self.model(im)[0]
elif self.dnn: # ONNX OpenCV DNN
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)

Loading…
Cancel
Save