Browse Source

Add TensorFlow and TFLite export (#1127)

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* detect.py TF inference

* Put representative dataset in tfl_int8 block

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* implement C3() and SiLU()

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warmup without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Add --agnostic-nms for TF class-agnostic NMS

* Cleanup after merge

* Cleanup2 after merge

* Cleanup3 after merge

* Add tf.py docstring with credit and usage

* pb saved_model and tflite use only one model in detect.py

* Add use cases in docstring of tf.py

* Remove redundant `stride` definition

* Remove keras direct import

* Fix `check_requirements(('tensorflow>=2.4.1',))`

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Jiacong Fang GitHub 3 years ago
parent
commit
808bcad3bb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 626 additions and 17 deletions
  1. +54
    -10
      detect.py
  2. +6
    -2
      models/experimental.py
  3. +558
    -0
      models/tf.py
  4. +1
    -0
      requirements.txt
  5. +7
    -5
      utils/datasets.py

+ 54
- 10
detect.py View File

from pathlib import Path from pathlib import Path


import cv2 import cv2
import numpy as np
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn


hide_labels=False, # hide labels hide_labels=False, # hide labels
hide_conf=False, # hide confidences hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
): ):
save_img = not nosave and not source.endswith('.txt') # save inference images save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
# Load model # Load model
w = weights[0] if isinstance(weights, list) else weights w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower() classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt: if pt:
model = attempt_load(weights, map_location=device) # load FP32 model model = attempt_load(weights, map_location=device) # load FP32 model
check_requirements(('onnx', 'onnxruntime')) check_requirements(('onnx', 'onnxruntime'))
import onnxruntime import onnxruntime
session = onnxruntime.InferenceSession(w, None) session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
imgsz = check_img_size(imgsz, s=stride) # check image size imgsz = check_img_size(imgsz, s=stride) # check image size


# Dataloader # Dataloader
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size bs = len(dataset) # batch_size
else: else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs


# Run inference # Run inference
if pt and device.type != 'cpu': if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time() t0 = time.time()
for path, img, im0s, vid_cap in dataset: for path, img, im0s, vid_cap in dataset:
if pt:
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device) img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32 img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx:
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0
img = img / 255.0 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3: if len(img.shape) == 3:
img = img[None] # expand for batch dim img = img[None] # expand for batch dim


pred = model(img, augment=augment, visualize=visualize)[0] pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx: elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img})) pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)


# NMS # NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)


def parse_opt(): def parse_opt():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam') parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image') parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args() opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt return opt





+ 6
- 2
models/experimental.py View File

return y, None # inference, train output return y, None # inference, train output




def attempt_load(weights, map_location=None, inplace=True):
def attempt_load(weights, map_location=None, inplace=True, fuse=True):
from models.yolo import Detect, Model from models.yolo import Detect, Model


# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
if fuse:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse



# Compatibility updates # Compatibility updates
for m in model.modules(): for m in model.modules():

+ 558
- 0
models/tf.py View File

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
TensorFlow/Keras and TFLite versions of YOLOv5
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127

Usage:
$ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml

Export int8 TFLite models:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \
--source path/to/images/ --ncalib 100

Detection:
$ python detect.py --weights yolov5s.pb --img 320
$ python detect.py --weights yolov5s_saved_model --img 320
$ python detect.py --weights yolov5s-fp16.tflite --img 320
$ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8

For TensorFlow.js:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
$ pip install tensorflowjs
$ tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
yolov5s.pb \
web_model
$ # Edit web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/web_model public/web_model
$ npm start
"""

import argparse
import logging
import os
import sys
import traceback
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import yaml
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.experimental import MixConv2d, CrossConv, attempt_load
from models.yolo import Detect
from utils.datasets import LoadImages
from utils.general import make_divisible, check_file, check_dataset

logger = logging.getLogger(__name__)


class tf_BN(keras.layers.Layer):
# TensorFlow BatchNormalization wrapper
def __init__(self, w=None):
super(tf_BN, self).__init__()
self.bn = keras.layers.BatchNormalization(
beta_initializer=keras.initializers.Constant(w.bias.numpy()),
gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
epsilon=w.eps)

def call(self, inputs):
return self.bn(inputs)


class tf_Pad(keras.layers.Layer):
def __init__(self, pad):
super(tf_Pad, self).__init__()
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])

def call(self, inputs):
return tf.pad(inputs, self.pad, mode='constant', constant_values=0)


class tf_Conv(keras.layers.Layer):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super(tf_Conv, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

conv = keras.layers.Conv2D(
c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv])
self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity

# YOLOv5 activations
if isinstance(w.act, nn.LeakyReLU):
self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
elif isinstance(w.act, nn.Hardswish):
self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
elif isinstance(w.act, nn.SiLU):
self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity

def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))


class tf_Focus(keras.layers.Layer):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, kernel, stride, padding, groups
super(tf_Focus, self).__init__()
self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv)

def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
# inputs = inputs / 255. # normalize 0-255 to 0-1
return self.conv(tf.concat([inputs[:, ::2, ::2, :],
inputs[:, 1::2, ::2, :],
inputs[:, ::2, 1::2, :],
inputs[:, 1::2, 1::2, :]], 3))


class tf_Bottleneck(keras.layers.Layer):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
super(tf_Bottleneck, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2)
self.add = shortcut and c1 == c2

def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class tf_Conv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
super(tf_Conv2d, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
self.conv = keras.layers.Conv2D(
c2, k, s, 'VALID', use_bias=bias,
kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )

def call(self, inputs):
return self.conv(inputs)


class tf_BottleneckCSP(keras.layers.Layer):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super(tf_BottleneckCSP, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4)
self.bn = tf_BN(w.bn)
self.act = lambda x: keras.activations.relu(x, alpha=0.1)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
y1 = self.cv3(self.m(self.cv1(inputs)))
y2 = self.cv2(inputs)
return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))


class tf_C3(keras.layers.Layer):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super(tf_C3, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class tf_SPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13), w=None):
super(tf_SPP, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]

def call(self, inputs):
x = self.cv1(inputs)
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class tf_Detect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), w=None): # detection layer
super(tf_Detect, self).__init__()
self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
[self.nl, 1, -1, 1, 2])
self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.export = False # onnx export
self.training = True # set to False after building model
for i in range(self.nl):
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
self.grid[i] = self._make_grid(nx, ny)

def call(self, inputs):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
x = []
for i in range(self.nl):
x.append(self.m[i](inputs[i]))
# x(bs,20,20,255) to x(bs,3,20,20,85)
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])

if not self.training: # inference
y = tf.sigmoid(x[i])
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
# Normalize xywh to 0-1 to reduce calibration error
xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
y = tf.concat([xy, wh, y[..., 4:]], -1)
z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))

return x if self.training else (tf.concat(z, 1), x)

@staticmethod
def _make_grid(nx=20, ny=20):
# yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
# return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)


class tf_Upsample(keras.layers.Layer):
def __init__(self, size, scale_factor, mode, w=None):
super(tf_Upsample, self).__init__()
assert scale_factor == 2, "scale_factor must be 2"
# self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
if opt.tf_raw_resize:
# with default arguments: align_corners=False, half_pixel_centers=False
self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
size=(x.shape[1] * 2, x.shape[2] * 2))
else:
self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)

def call(self, inputs):
return self.upsample(inputs)


class tf_Concat(keras.layers.Layer):
def __init__(self, dimension=1, w=None):
super(tf_Concat, self).__init__()
assert dimension == 1, "convert only NCHW to NHWC concat"
self.d = 3

def call(self, inputs):
return tf.concat(inputs, self.d)


def parse_model(d, ch, model): # model_dict, input_channels(3)
logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)

layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m_str = m
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
args[j] = eval(a) if isinstance(a, str) else a # eval strings
except:
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
elif m is Detect:
args.append([ch[x + 1] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
else:
c2 = ch[f]

tf_m = eval('tf_' + m_str.replace('nn.', ''))
m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
else tf_m(*args, w=model.model[i]) # module

torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in torch_m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
return keras.Sequential(layers), sorted(save)


class tf_Model():
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None): # model, input channels, number of classes
super(tf_Model, self).__init__()
if isinstance(cfg, dict):
self.yaml = cfg # model dict
else: # is *.yaml
import yaml # for torch hub
self.yaml_file = Path(cfg).name
with open(cfg) as f:
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict

# Define model
if nc and nc != self.yaml['nc']:
print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model) # model, savelist, ch_out

def predict(self, inputs, profile=False):
y = [] # outputs
x = inputs
for i, m in enumerate(self.model.layers):
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

x = m(x) # run
y.append(x if m.i in self.savelist else None) # save output

# Add TensorFlow NMS
if opt.tf_nms:
boxes = xywh2xyxy(x[0][..., :4])
probs = x[0][:, :, 4:5]
classes = x[0][:, :, 5:]
scores = probs * classes
if opt.agnostic_nms:
nms = agnostic_nms_layer()((boxes, classes, scores))
return nms, x[1]
else:
boxes = tf.expand_dims(boxes, 2)
nms = tf.image.combined_non_max_suppression(
boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False)
return nms, x[1]

return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
# x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
# xywh = x[..., :4] # x(6300,4) boxes
# conf = x[..., 4:5] # x(6300,1) confidences
# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
# return tf.concat([conf, cls, xywh], 1)


class agnostic_nms_layer(keras.layers.Layer):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
def call(self, input):
return tf.map_fn(agnostic_nms, input,
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
name='agnostic_nms')


def agnostic_nms(x):
boxes, classes, scores = x
class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
scores_inp = tf.reduce_max(scores, -1)
selected_inds = tf.image.non_max_suppression(
boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres)
selected_boxes = tf.gather(boxes, selected_inds)
padded_boxes = tf.pad(selected_boxes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
mode="CONSTANT", constant_values=0.0)
selected_scores = tf.gather(scores_inp, selected_inds)
padded_scores = tf.pad(selected_scores,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0)
selected_classes = tf.gather(class_inds, selected_inds)
padded_classes = tf.pad(selected_classes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0)
valid_detections = tf.shape(selected_inds)[0]
return padded_boxes, padded_scores, padded_classes, valid_detections


def xywh2xyxy(xywh):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)


def representative_dataset_gen():
# Representative dataset for use with converter.representative_dataset
n = 0
for path, img, im0s, vid_cap in dataset:
# Get sample input data as a numpy array in a method of your choosing.
n += 1
input = np.transpose(img, [1, 2, 0])
input = np.expand_dims(input, axis=0).astype(np.float32)
input /= 255.0
yield [input]
if n >= opt.ncalib:
break


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path')
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size')
parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
help='use tf.raw_ops.ResizeNearestNeighbor for resize')
parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)

# Input
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection

# Load PyTorch model
model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
model.model[-1].export = False # set Detect() layer export=True
y = model(img) # dry run
nc = y[0].shape[-1] - 5

# TensorFlow saved_model export
try:
print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
tf_model = tf_Model(opt.cfg, model=model, nc=nc)
img = tf.zeros((opt.batch_size, *opt.img_size, 3)) # NHWC Input for TensorFlow

m = tf_model.model.layers[-1]
assert isinstance(m, tf_Detect), "the last layer must be Detect"
m.training = False
y = tf_model.predict(img)

inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
keras_model.summary()
path = opt.weights.replace('.pt', '_saved_model') # filename
keras_model.save(path, save_format='tf')
print('TensorFlow saved_model export success, saved as %s' % path)
except Exception as e:
print('TensorFlow saved_model export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

# TensorFlow GraphDef export
try:
print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)

# https://github.com/leimao/Frozen_Graph_TensorFlow
full_model = tf.function(lambda x: keras_model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))

frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
f = opt.weights.replace('.pt', '.pb') # filename
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir=os.path.dirname(f),
name=os.path.basename(f),
as_text=False)

print('TensorFlow GraphDef export success, saved as %s' % f)
except Exception as e:
print('TensorFlow GraphDef export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

# TFLite model export
if not opt.tf_nms:
try:
print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)

# fp32 TFLite model export ---------------------------------------------------------------------------------
# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
# converter.allow_custom_ops = False
# converter.experimental_new_converter = True
# tflite_model = converter.convert()
# f = opt.weights.replace('.pt', '.tflite') # filename
# open(f, "wb").write(tflite_model)

# fp16 TFLite model export ---------------------------------------------------------------------------------
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_types = [tf.float16]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = False
converter.experimental_new_converter = True
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-fp16.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite export success, saved as %s' % f)

# int8 TFLite model export ---------------------------------------------------------------------------------
if opt.tfl_int8:
# Representative Dataset
if opt.source.endswith('.yaml'):
with open(check_file(opt.source)) as f:
data = yaml.load(f, Loader=yaml.FullLoader) # data dict
check_dataset(data) # check
opt.source = data['train']
dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.allow_custom_ops = False
converter.experimental_new_converter = True
converter.experimental_new_quantizer = False
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-int8.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite (int8) export success, saved as %s' % f)

except Exception as e:
print('\nTFLite export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

+ 1
- 0
requirements.txt View File

# coremltools>=4.1 # coremltools>=4.1
# onnx>=1.9.0 # onnx>=1.9.0
# scikit-learn==0.19.2 # for coreml quantization # scikit-learn==0.19.2 # for coreml quantization
# tensorflow==2.4.1 # for TFLite export


# extras -------------------------------------- # extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172

+ 7
- 5
utils/datasets.py View File





class LoadImages: # for inference class LoadImages: # for inference
def __init__(self, path, img_size=640, stride=32):
def __init__(self, path, img_size=640, stride=32, auto=True):
p = str(Path(path).absolute()) # os-agnostic absolute path p = str(Path(path).absolute()) # os-agnostic absolute path
if '*' in p: if '*' in p:
files = sorted(glob.glob(p, recursive=True)) # glob files = sorted(glob.glob(p, recursive=True)) # glob
self.nf = ni + nv # number of files self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv self.video_flag = [False] * ni + [True] * nv
self.mode = 'image' self.mode = 'image'
self.auto = auto
if any(videos): if any(videos):
self.new_video(videos[0]) # new video self.new_video(videos[0]) # new video
else: else:
print(f'image {self.count}/{self.nf} {path}: ', end='') print(f'image {self.count}/{self.nf} {path}: ', end='')


# Padded resize # Padded resize
img = letterbox(img0, self.img_size, stride=self.stride)[0]
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]


# Convert # Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB




class LoadStreams: # multiple IP or RTSP cameras class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640, stride=32):
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
self.mode = 'stream' self.mode = 'stream'
self.img_size = img_size self.img_size = img_size
self.stride = stride self.stride = stride
n = len(sources) n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream # Start thread to read frames from video stream
print(f'{i + 1}/{n}: {s}... ', end='') print(f'{i + 1}/{n}: {s}... ', end='')
print('') # newline print('') # newline


# check for common shapes # check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect: if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')


# Letterbox # Letterbox
img0 = self.imgs.copy() img0 = self.imgs.copy()
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]


# Stack # Stack
img = np.stack(img, 0) img = np.stack(img, 0)

Loading…
Cancel
Save