Quellcode durchsuchen

Add EdgeTPU support (#3630)

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* detect.py TF inference

* Put representative dataset in tfl_int8 block

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* implement C3() and SiLU()

* Add TensorFlow and TFLite Detection

* Add --tfl-detect for TFLite Detection

* Add int8 quantized TFLite inference in detect.py

* Add --edgetpu for Edge TPU detection

* Fix --img-size to add rectangle TensorFlow and TFLite input

* Add --no-tf-nms to detect objects using models combined with TensorFlow NMS

* Fix --img-size list type input

* Update README.md

* Add Android project for TFLite inference

* Upgrade TensorFlow v2.3.1 -> v2.4.0

* Disable normalization of xywh

* Rewrite names init in detect.py

* Change input resolution 640 -> 320 on Android

* Disable NNAPI

* Update README.me --img 640 -> 320

* Update README.me for Edge TPU

* Update README.md

* Fix reshape dim to support dynamic batching

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warmup without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Remove android directory

* Update README.md

* Update README.md

* Add multiple OS support for EdgeTPU detection

* Fix export and detect

* Export 3 YOLO heads with Edge TPU models

* Remove xywh denormalization with Edge TPU models in detect.py

* Fix saved_model and pb detect error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix pre-commit.ci failure

* Add edgetpu in export.py docstring

* Fix Edge TPU model detection exported by TF 2.7

* Add class names for TF/TFLite in DetectMultibackend

* Fix assignment with nl in TFLite Detection

* Add check when getting Edge TPU compiler version

* Add UTF-8 encoding in opening --data file for Windows

* Remove redundant TensorFlow import

* Add Edge TPU in export.py's docstring

* Add the detect layer in Edge TPU model conversion

* Default `dnn=False`

* Cleanup data.yaml loading

* Update detect.py

* Update val.py

* Comments and generalize data.yaml names

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: unknown <fangjiacong@ut.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Jiacong Fang GitHub vor 2 Jahren
Ursprung
Commit
d95978a562
Es konnte kein GPG-Schlüssel zu dieser Signatur gefunden werden GPG-Schlüssel-ID: 4AEE18F83AFDEB23
4 geänderte Dateien mit 37 neuen und 8 gelöschten Zeilen
  1. +3
    -1
      detect.py
  2. +25
    -4
      export.py
  3. +8
    -2
      models/common.py
  4. +1
    -1
      val.py

+ 3
- 1
detect.py Datei anzeigen

@@ -38,6 +38,7 @@ from utils.torch_utils import select_device, time_sync
@torch.no_grad()
def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.25, # confidence threshold
iou_thres=0.45, # NMS IOU threshold
@@ -76,7 +77,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size

@@ -204,6 +205,7 @@ def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')

+ 25
- 4
export.py Datei anzeigen

@@ -248,6 +248,24 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
try:
cmd = 'edgetpu_compiler --version'
out = subprocess.run(cmd, shell=True, capture_output=True, check=True)
ver = out.stdout.decode().split()[-1]
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
f = str(file).replace('.pt', '-int8_edgetpu.tflite')
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model

cmd = f"edgetpu_compiler -s {f_tfl}"
subprocess.run(cmd, shell=True, check=True)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export
try:
@@ -285,6 +303,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):


def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
try:
check_requirements(('tensorrt',))
import tensorrt as trt
@@ -356,7 +375,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
):
t = time.time()
include = [x.lower() for x in include]
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)

# Checks
@@ -405,15 +424,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'

# TensorFlow Exports
if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:]
pb, tflite, edgetpu, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
conf_thres=conf_thres, iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file)
if tflite:
export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
if tflite or edgetpu:
export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
if edgetpu:
export_edgetpu(model, im, file)
if tfjs:
export_tfjs(model, im, file)


+ 8
- 2
models/common.py Datei anzeigen

@@ -17,6 +17,7 @@ import pandas as pd
import requests
import torch
import torch.nn as nn
import yaml
from PIL import Image
from torch.cuda import amp

@@ -276,7 +277,7 @@ class Concat(nn.Module):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
@@ -284,6 +285,7 @@ class DetectMultiBackend(nn.Module):
# TensorFlow: *_saved_model
# TensorFlow: *.pb
# TensorFlow Lite: *.tflite
# TensorFlow Edge TPU: *_edgetpu.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
@@ -297,6 +299,9 @@ class DetectMultiBackend(nn.Module):
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
w = attempt_download(w) # download if not local
if data: # data.yaml path (optional)
with open(data, errors='ignore') as f:
names = yaml.safe_load(f)['names'] # class names

if jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -343,7 +348,7 @@ class DetectMultiBackend(nn.Module):
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
else: # TensorFlow (TFLite, pb, saved_model)
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
import tensorflow as tf
@@ -425,6 +430,7 @@ class DetectMultiBackend(nn.Module):
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h

y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y


+ 1
- 1
val.py Datei anzeigen

@@ -124,7 +124,7 @@ def run(data,
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA

Laden…
Abbrechen
Speichern