Ver código fonte

Add TensorFlow and TFLite export (#1127)

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* detect.py TF inference

* Put representative dataset in tfl_int8 block

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* implement C3() and SiLU()

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warmup without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Add --agnostic-nms for TF class-agnostic NMS

* Cleanup after merge

* Cleanup2 after merge

* Cleanup3 after merge

* Add tf.py docstring with credit and usage

* pb saved_model and tflite use only one model in detect.py

* Add use cases in docstring of tf.py

* Remove redundant `stride` definition

* Remove keras direct import

* Fix `check_requirements(('tensorflow>=2.4.1',))`

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Jiacong Fang GitHub 3 anos atrás
pai
commit
808bcad3bb
Nenhuma chave conhecida encontrada para esta assinatura no banco de dados ID da chave GPG: 4AEE18F83AFDEB23
5 arquivos alterados com 626 adições e 17 exclusões
  1. +54
    -10
      detect.py
  2. +6
    -2
      models/experimental.py
  3. +558
    -0
      models/tf.py
  4. +1
    -0
      requirements.txt
  5. +7
    -5
      utils/datasets.py

+ 54
- 10
detect.py Ver arquivo

@@ -12,6 +12,7 @@ import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

@@ -51,6 +52,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
@@ -68,7 +70,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Load model
w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt:
model = attempt_load(weights, map_location=device) # load FP32 model
@@ -83,30 +85,49 @@ def run(weights='yolov5s.pt', # model.pt path(s)
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
imgsz = check_img_size(imgsz, s=stride) # check image size

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
if pt:
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx:
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0
img = img / 255.0 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
img = img[None] # expand for batch dim

@@ -117,6 +138,27 @@ def run(weights='yolov5s.pt', # model.pt path(s)
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)

# NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
@@ -202,9 +244,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
@@ -226,7 +268,9 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt



+ 6
- 2
models/experimental.py Ver arquivo

@@ -85,14 +85,18 @@ class Ensemble(nn.ModuleList):
return y, None # inference, train output


def attempt_load(weights, map_location=None, inplace=True):
def attempt_load(weights, map_location=None, inplace=True, fuse=True):
from models.yolo import Detect, Model

# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
if fuse:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse


# Compatibility updates
for m in model.modules():

+ 558
- 0
models/tf.py Ver arquivo

@@ -0,0 +1,558 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
TensorFlow/Keras and TFLite versions of YOLOv5
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127

Usage:
$ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml

Export int8 TFLite models:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \
--source path/to/images/ --ncalib 100

Detection:
$ python detect.py --weights yolov5s.pb --img 320
$ python detect.py --weights yolov5s_saved_model --img 320
$ python detect.py --weights yolov5s-fp16.tflite --img 320
$ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8

For TensorFlow.js:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
$ pip install tensorflowjs
$ tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
yolov5s.pb \
web_model
$ # Edit web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/web_model public/web_model
$ npm start
"""

import argparse
import logging
import os
import sys
import traceback
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import yaml
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.experimental import MixConv2d, CrossConv, attempt_load
from models.yolo import Detect
from utils.datasets import LoadImages
from utils.general import make_divisible, check_file, check_dataset

logger = logging.getLogger(__name__)


class tf_BN(keras.layers.Layer):
# TensorFlow BatchNormalization wrapper
def __init__(self, w=None):
super(tf_BN, self).__init__()
self.bn = keras.layers.BatchNormalization(
beta_initializer=keras.initializers.Constant(w.bias.numpy()),
gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
epsilon=w.eps)

def call(self, inputs):
return self.bn(inputs)


class tf_Pad(keras.layers.Layer):
def __init__(self, pad):
super(tf_Pad, self).__init__()
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])

def call(self, inputs):
return tf.pad(inputs, self.pad, mode='constant', constant_values=0)


class tf_Conv(keras.layers.Layer):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super(tf_Conv, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

conv = keras.layers.Conv2D(
c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv])
self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity

# YOLOv5 activations
if isinstance(w.act, nn.LeakyReLU):
self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
elif isinstance(w.act, nn.Hardswish):
self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
elif isinstance(w.act, nn.SiLU):
self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity

def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))


class tf_Focus(keras.layers.Layer):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, kernel, stride, padding, groups
super(tf_Focus, self).__init__()
self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv)

def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
# inputs = inputs / 255. # normalize 0-255 to 0-1
return self.conv(tf.concat([inputs[:, ::2, ::2, :],
inputs[:, 1::2, ::2, :],
inputs[:, ::2, 1::2, :],
inputs[:, 1::2, 1::2, :]], 3))


class tf_Bottleneck(keras.layers.Layer):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
super(tf_Bottleneck, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2)
self.add = shortcut and c1 == c2

def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class tf_Conv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
super(tf_Conv2d, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
self.conv = keras.layers.Conv2D(
c2, k, s, 'VALID', use_bias=bias,
kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )

def call(self, inputs):
return self.conv(inputs)


class tf_BottleneckCSP(keras.layers.Layer):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super(tf_BottleneckCSP, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4)
self.bn = tf_BN(w.bn)
self.act = lambda x: keras.activations.relu(x, alpha=0.1)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
y1 = self.cv3(self.m(self.cv1(inputs)))
y2 = self.cv2(inputs)
return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))


class tf_C3(keras.layers.Layer):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super(tf_C3, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class tf_SPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13), w=None):
super(tf_SPP, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]

def call(self, inputs):
x = self.cv1(inputs)
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class tf_Detect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), w=None): # detection layer
super(tf_Detect, self).__init__()
self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
[self.nl, 1, -1, 1, 2])
self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.export = False # onnx export
self.training = True # set to False after building model
for i in range(self.nl):
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
self.grid[i] = self._make_grid(nx, ny)

def call(self, inputs):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
x = []
for i in range(self.nl):
x.append(self.m[i](inputs[i]))
# x(bs,20,20,255) to x(bs,3,20,20,85)
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])

if not self.training: # inference
y = tf.sigmoid(x[i])
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
# Normalize xywh to 0-1 to reduce calibration error
xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
y = tf.concat([xy, wh, y[..., 4:]], -1)
z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))

return x if self.training else (tf.concat(z, 1), x)

@staticmethod
def _make_grid(nx=20, ny=20):
# yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
# return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)


class tf_Upsample(keras.layers.Layer):
def __init__(self, size, scale_factor, mode, w=None):
super(tf_Upsample, self).__init__()
assert scale_factor == 2, "scale_factor must be 2"
# self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
if opt.tf_raw_resize:
# with default arguments: align_corners=False, half_pixel_centers=False
self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
size=(x.shape[1] * 2, x.shape[2] * 2))
else:
self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)

def call(self, inputs):
return self.upsample(inputs)


class tf_Concat(keras.layers.Layer):
def __init__(self, dimension=1, w=None):
super(tf_Concat, self).__init__()
assert dimension == 1, "convert only NCHW to NHWC concat"
self.d = 3

def call(self, inputs):
return tf.concat(inputs, self.d)


def parse_model(d, ch, model): # model_dict, input_channels(3)
logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)

layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m_str = m
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
args[j] = eval(a) if isinstance(a, str) else a # eval strings
except:
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
elif m is Detect:
args.append([ch[x + 1] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
else:
c2 = ch[f]

tf_m = eval('tf_' + m_str.replace('nn.', ''))
m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
else tf_m(*args, w=model.model[i]) # module

torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in torch_m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
return keras.Sequential(layers), sorted(save)


class tf_Model():
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None): # model, input channels, number of classes
super(tf_Model, self).__init__()
if isinstance(cfg, dict):
self.yaml = cfg # model dict
else: # is *.yaml
import yaml # for torch hub
self.yaml_file = Path(cfg).name
with open(cfg) as f:
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict

# Define model
if nc and nc != self.yaml['nc']:
print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model) # model, savelist, ch_out

def predict(self, inputs, profile=False):
y = [] # outputs
x = inputs
for i, m in enumerate(self.model.layers):
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

x = m(x) # run
y.append(x if m.i in self.savelist else None) # save output

# Add TensorFlow NMS
if opt.tf_nms:
boxes = xywh2xyxy(x[0][..., :4])
probs = x[0][:, :, 4:5]
classes = x[0][:, :, 5:]
scores = probs * classes
if opt.agnostic_nms:
nms = agnostic_nms_layer()((boxes, classes, scores))
return nms, x[1]
else:
boxes = tf.expand_dims(boxes, 2)
nms = tf.image.combined_non_max_suppression(
boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False)
return nms, x[1]

return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
# x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
# xywh = x[..., :4] # x(6300,4) boxes
# conf = x[..., 4:5] # x(6300,1) confidences
# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
# return tf.concat([conf, cls, xywh], 1)


class agnostic_nms_layer(keras.layers.Layer):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
def call(self, input):
return tf.map_fn(agnostic_nms, input,
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
name='agnostic_nms')


def agnostic_nms(x):
boxes, classes, scores = x
class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
scores_inp = tf.reduce_max(scores, -1)
selected_inds = tf.image.non_max_suppression(
boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres)
selected_boxes = tf.gather(boxes, selected_inds)
padded_boxes = tf.pad(selected_boxes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
mode="CONSTANT", constant_values=0.0)
selected_scores = tf.gather(scores_inp, selected_inds)
padded_scores = tf.pad(selected_scores,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0)
selected_classes = tf.gather(class_inds, selected_inds)
padded_classes = tf.pad(selected_classes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0)
valid_detections = tf.shape(selected_inds)[0]
return padded_boxes, padded_scores, padded_classes, valid_detections


def xywh2xyxy(xywh):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)


def representative_dataset_gen():
# Representative dataset for use with converter.representative_dataset
n = 0
for path, img, im0s, vid_cap in dataset:
# Get sample input data as a numpy array in a method of your choosing.
n += 1
input = np.transpose(img, [1, 2, 0])
input = np.expand_dims(input, axis=0).astype(np.float32)
input /= 255.0
yield [input]
if n >= opt.ncalib:
break


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path')
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size')
parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
help='use tf.raw_ops.ResizeNearestNeighbor for resize')
parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)

# Input
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection

# Load PyTorch model
model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
model.model[-1].export = False # set Detect() layer export=True
y = model(img) # dry run
nc = y[0].shape[-1] - 5

# TensorFlow saved_model export
try:
print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
tf_model = tf_Model(opt.cfg, model=model, nc=nc)
img = tf.zeros((opt.batch_size, *opt.img_size, 3)) # NHWC Input for TensorFlow

m = tf_model.model.layers[-1]
assert isinstance(m, tf_Detect), "the last layer must be Detect"
m.training = False
y = tf_model.predict(img)

inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
keras_model.summary()
path = opt.weights.replace('.pt', '_saved_model') # filename
keras_model.save(path, save_format='tf')
print('TensorFlow saved_model export success, saved as %s' % path)
except Exception as e:
print('TensorFlow saved_model export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

# TensorFlow GraphDef export
try:
print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)

# https://github.com/leimao/Frozen_Graph_TensorFlow
full_model = tf.function(lambda x: keras_model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))

frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
f = opt.weights.replace('.pt', '.pb') # filename
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir=os.path.dirname(f),
name=os.path.basename(f),
as_text=False)

print('TensorFlow GraphDef export success, saved as %s' % f)
except Exception as e:
print('TensorFlow GraphDef export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

# TFLite model export
if not opt.tf_nms:
try:
print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)

# fp32 TFLite model export ---------------------------------------------------------------------------------
# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
# converter.allow_custom_ops = False
# converter.experimental_new_converter = True
# tflite_model = converter.convert()
# f = opt.weights.replace('.pt', '.tflite') # filename
# open(f, "wb").write(tflite_model)

# fp16 TFLite model export ---------------------------------------------------------------------------------
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_types = [tf.float16]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = False
converter.experimental_new_converter = True
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-fp16.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite export success, saved as %s' % f)

# int8 TFLite model export ---------------------------------------------------------------------------------
if opt.tfl_int8:
# Representative Dataset
if opt.source.endswith('.yaml'):
with open(check_file(opt.source)) as f:
data = yaml.load(f, Loader=yaml.FullLoader) # data dict
check_dataset(data) # check
opt.source = data['train']
dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.allow_custom_ops = False
converter.experimental_new_converter = True
converter.experimental_new_quantizer = False
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-int8.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite (int8) export success, saved as %s' % f)

except Exception as e:
print('\nTFLite export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

+ 1
- 0
requirements.txt Ver arquivo

@@ -23,6 +23,7 @@ pandas
# coremltools>=4.1
# onnx>=1.9.0
# scikit-learn==0.19.2 # for coreml quantization
# tensorflow==2.4.1 # for TFLite export

# extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172

+ 7
- 5
utils/datasets.py Ver arquivo

@@ -155,7 +155,7 @@ class _RepeatSampler(object):


class LoadImages: # for inference
def __init__(self, path, img_size=640, stride=32):
def __init__(self, path, img_size=640, stride=32, auto=True):
p = str(Path(path).absolute()) # os-agnostic absolute path
if '*' in p:
files = sorted(glob.glob(p, recursive=True)) # glob
@@ -176,6 +176,7 @@ class LoadImages: # for inference
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = 'image'
self.auto = auto
if any(videos):
self.new_video(videos[0]) # new video
else:
@@ -217,7 +218,7 @@ class LoadImages: # for inference
print(f'image {self.count}/{self.nf} {path}: ', end='')

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride)[0]
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

# Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
@@ -276,7 +277,7 @@ class LoadWebcam: # for inference


class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640, stride=32):
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
@@ -290,6 +291,7 @@ class LoadStreams: # multiple IP or RTSP cameras
n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
print(f'{i + 1}/{n}: {s}... ', end='')
@@ -312,7 +314,7 @@ class LoadStreams: # multiple IP or RTSP cameras
print('') # newline

# check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
@@ -341,7 +343,7 @@ class LoadStreams: # multiple IP or RTSP cameras

# Letterbox
img0 = self.imgs.copy()
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]

# Stack
img = np.stack(img, 0)

Carregando…
Cancelar
Salvar