* Add models/tf.py for TensorFlow and TFLite export * Set auto=False for int8 calibration * Update requirements.txt for TensorFlow and TFLite export * Read anchors directly from PyTorch weights * Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export * Remove check_anchor_order, check_file, set_logging from import * Reformat code and optimize imports * Autodownload model and check cfg * update --source path, img-size to 320, single output * Adjust representative_dataset * Put representative dataset in tfl_int8 block * detect.py TF inference * weights to string * weights to string * cleanup tf.py * Add --dynamic-batch-size * Add xywh normalization to reduce calibration error * Update requirements.txt TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error * Fix imports Move C3 from models.experimental to models.common * Add models/tf.py for TensorFlow and TFLite export * Set auto=False for int8 calibration * Update requirements.txt for TensorFlow and TFLite export * Read anchors directly from PyTorch weights * Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export * Remove check_anchor_order, check_file, set_logging from import * Reformat code and optimize imports * Autodownload model and check cfg * update --source path, img-size to 320, single output * Adjust representative_dataset * detect.py TF inference * Put representative dataset in tfl_int8 block * weights to string * weights to string * cleanup tf.py * Add --dynamic-batch-size * Add xywh normalization to reduce calibration error * Update requirements.txt TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error * Fix imports Move C3 from models.experimental to models.common * implement C3() and SiLU() * Add TensorFlow and TFLite Detection * Add --tfl-detect for TFLite Detection * Add int8 quantized TFLite inference in detect.py * Add --edgetpu for Edge TPU detection * Fix --img-size to add rectangle TensorFlow and TFLite input * Add --no-tf-nms to detect objects using models combined with TensorFlow NMS * Fix --img-size list type input * Update README.md * Add Android project for TFLite inference * Upgrade TensorFlow v2.3.1 -> v2.4.0 * Disable normalization of xywh * Rewrite names init in detect.py * Change input resolution 640 -> 320 on Android * Disable NNAPI * Update README.me --img 640 -> 320 * Update README.me for Edge TPU * Update README.md * Fix reshape dim to support dynamic batching * Fix reshape dim to support dynamic batching * Add epsilon argument in tf_BN, which is different between TF and PT * Set stride to None if not using PyTorch, and do not warmup without PyTorch * Add list support in check_img_size() * Add list input support in detect.py * sys.path.append('./') to run from yolov5/ * Add int8 quantization support for TensorFlow 2.5 * Add get_coco128.sh * Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU) * Update requirements.txt * Replace torch.load() with attempt_load() * Update requirements.txt * Add --tf-raw-resize to set half_pixel_centers=False * Remove android directory * Update README.md * Update README.md * Add multiple OS support for EdgeTPU detection * Fix export and detect * Export 3 YOLO heads with Edge TPU models * Remove xywh denormalization with Edge TPU models in detect.py * Fix saved_model and pb detect error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix pre-commit.ci failure * Add edgetpu in export.py docstring * Fix Edge TPU model detection exported by TF 2.7 * Add class names for TF/TFLite in DetectMultibackend * Fix assignment with nl in TFLite Detection * Add check when getting Edge TPU compiler version * Add UTF-8 encoding in opening --data file for Windows * Remove redundant TensorFlow import * Add Edge TPU in export.py's docstring * Add the detect layer in Edge TPU model conversion * Default `dnn=False` * Cleanup data.yaml loading * Update detect.py * Update val.py * Comments and generalize data.yaml names Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: unknown <fangjiacong@ut.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>modifyDataloader
@torch.no_grad() | @torch.no_grad() | ||||
def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | ||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam | source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam | ||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path | |||||
imgsz=(640, 640), # inference size (height, width) | imgsz=(640, 640), # inference size (height, width) | ||||
conf_thres=0.25, # confidence threshold | conf_thres=0.25, # confidence threshold | ||||
iou_thres=0.45, # NMS IOU threshold | iou_thres=0.45, # NMS IOU threshold | ||||
# Load model | # Load model | ||||
device = select_device(device) | device = select_device(device) | ||||
model = DetectMultiBackend(weights, device=device, dnn=dnn) | |||||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) | |||||
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine | stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine | ||||
imgsz = check_img_size(imgsz, s=stride) # check image size | imgsz = check_img_size(imgsz, s=stride) # check image size | ||||
parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') | ||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') | parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') | ||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') | |||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') | ||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') | parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') | ||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') | parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') |
LOGGER.info(f'\n{prefix} export failure: {e}') | LOGGER.info(f'\n{prefix} export failure: {e}') | ||||
def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')): | |||||
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ | |||||
try: | |||||
cmd = 'edgetpu_compiler --version' | |||||
out = subprocess.run(cmd, shell=True, capture_output=True, check=True) | |||||
ver = out.stdout.decode().split()[-1] | |||||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') | |||||
f = str(file).replace('.pt', '-int8_edgetpu.tflite') | |||||
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model | |||||
cmd = f"edgetpu_compiler -s {f_tfl}" | |||||
subprocess.run(cmd, shell=True, check=True) | |||||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||||
except Exception as e: | |||||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||||
def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | ||||
# YOLOv5 TensorFlow.js export | # YOLOv5 TensorFlow.js export | ||||
try: | try: | ||||
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): | def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): | ||||
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt | |||||
try: | try: | ||||
check_requirements(('tensorrt',)) | check_requirements(('tensorrt',)) | ||||
import tensorrt as trt | import tensorrt as trt | ||||
): | ): | ||||
t = time.time() | t = time.time() | ||||
include = [x.lower() for x in include] | include = [x.lower() for x in include] | ||||
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports | |||||
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports | |||||
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) | ||||
# Checks | # Checks | ||||
# TensorFlow Exports | # TensorFlow Exports | ||||
if any(tf_exports): | if any(tf_exports): | ||||
pb, tflite, tfjs = tf_exports[1:] | |||||
pb, tflite, edgetpu, tfjs = tf_exports[1:] | |||||
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' | assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' | ||||
model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs, | model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs, | ||||
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all, | agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all, | ||||
conf_thres=conf_thres, iou_thres=iou_thres) # keras model | conf_thres=conf_thres, iou_thres=iou_thres) # keras model | ||||
if pb or tfjs: # pb prerequisite to tfjs | if pb or tfjs: # pb prerequisite to tfjs | ||||
export_pb(model, im, file) | export_pb(model, im, file) | ||||
if tflite: | |||||
export_tflite(model, im, file, int8=int8, data=data, ncalib=100) | |||||
if tflite or edgetpu: | |||||
export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100) | |||||
if edgetpu: | |||||
export_edgetpu(model, im, file) | |||||
if tfjs: | if tfjs: | ||||
export_tfjs(model, im, file) | export_tfjs(model, im, file) | ||||
import requests | import requests | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import yaml | |||||
from PIL import Image | from PIL import Image | ||||
from torch.cuda import amp | from torch.cuda import amp | ||||
class DetectMultiBackend(nn.Module): | class DetectMultiBackend(nn.Module): | ||||
# YOLOv5 MultiBackend class for python inference on various backends | # YOLOv5 MultiBackend class for python inference on various backends | ||||
def __init__(self, weights='yolov5s.pt', device=None, dnn=False): | |||||
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): | |||||
# Usage: | # Usage: | ||||
# PyTorch: weights = *.pt | # PyTorch: weights = *.pt | ||||
# TorchScript: *.torchscript | # TorchScript: *.torchscript | ||||
# TensorFlow: *_saved_model | # TensorFlow: *_saved_model | ||||
# TensorFlow: *.pb | # TensorFlow: *.pb | ||||
# TensorFlow Lite: *.tflite | # TensorFlow Lite: *.tflite | ||||
# TensorFlow Edge TPU: *_edgetpu.tflite | |||||
# ONNX Runtime: *.onnx | # ONNX Runtime: *.onnx | ||||
# OpenCV DNN: *.onnx with dnn=True | # OpenCV DNN: *.onnx with dnn=True | ||||
# TensorRT: *.engine | # TensorRT: *.engine | ||||
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans | pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans | ||||
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults | stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults | ||||
w = attempt_download(w) # download if not local | w = attempt_download(w) # download if not local | ||||
if data: # data.yaml path (optional) | |||||
with open(data, errors='ignore') as f: | |||||
names = yaml.safe_load(f)['names'] # class names | |||||
if jit: # TorchScript | if jit: # TorchScript | ||||
LOGGER.info(f'Loading {w} for TorchScript inference...') | LOGGER.info(f'Loading {w} for TorchScript inference...') | ||||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | ||||
context = model.create_execution_context() | context = model.create_execution_context() | ||||
batch_size = bindings['images'].shape[0] | batch_size = bindings['images'].shape[0] | ||||
else: # TensorFlow model (TFLite, pb, saved_model) | |||||
else: # TensorFlow (TFLite, pb, saved_model) | |||||
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | ||||
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') | LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') | ||||
import tensorflow as tf | import tensorflow as tf | ||||
y[..., 1] *= h # y | y[..., 1] *= h # y | ||||
y[..., 2] *= w # w | y[..., 2] *= w # w | ||||
y[..., 3] *= h # h | y[..., 3] *= h # h | ||||
y = torch.tensor(y) if isinstance(y, np.ndarray) else y | y = torch.tensor(y) if isinstance(y, np.ndarray) else y | ||||
return (y, []) if val else y | return (y, []) if val else y | ||||
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir | ||||
# Load model | # Load model | ||||
model = DetectMultiBackend(weights, device=device, dnn=dnn) | |||||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) | |||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine | stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine | ||||
imgsz = check_img_size(imgsz, s=stride) # check image size | imgsz = check_img_size(imgsz, s=stride) # check image size | ||||
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA | half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA |