Ver código fonte

Add OpenCV DNN option for ONNX inference (#5136)

* Add OpenCV DNN option for ONNX inference

Usage:

```bash
python detect.py --weights yolov5s.onnx  # ONNX Runtime inference
python detect.py --weights yolov5s.onnx -dnn  # OpenCV DNN inference
```

* DNN prediction to tensor

* Update detect.py
modifyDataloader
Glenn Jocher GitHub 2 anos atrás
pai
commit
0bf24cf641
Nenhuma chave conhecida encontrada para esta assinatura no banco de dados ID da chave GPG: 4AEE18F83AFDEB23
1 arquivos alterados com 15 adições e 5 exclusões
  1. +15
    -5
      detect.py

+ 15
- 5
detect.py Ver arquivo

@@ -56,6 +56,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
@@ -72,7 +73,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
half &= device.type != 'cpu' # half precision only supported on CUDA

# Load model
w = weights[0] if isinstance(weights, list) else weights
w = str(weights[0] if isinstance(weights, list) else weights)
classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
@@ -87,9 +88,13 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
if dnn:
# check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
else:
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
@@ -145,7 +150,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
if dnn:
net.setInput(img)
pred = torch.tensor(net.forward())
else:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
@@ -281,6 +290,7 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(FILE.stem, opt)

Carregando…
Cancelar
Salvar