Browse Source

New TensorFlow `TFDWConv()` module (#7824)

* New TensorFlow `TFDWConv()` module

Analog to DWConv() module:
8aa2085a7e/models/common.py (L53-L57)

* Fix and new activations() function

* Update tf.py
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
4d59f65db5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 41 additions and 18 deletions
  1. +41
    -18
      models/tf.py

+ 41
- 18
models/tf.py View File

@@ -70,25 +70,38 @@ class TFConv(keras.layers.Layer):
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

conv = keras.layers.Conv2D(
c2,
k,
s,
'SAME' if s == 1 else 'VALID',
use_bias=False if hasattr(w, 'bn') else True,
filters=c2,
kernel_size=k,
strides=s,
padding='SAME' if s == 1 else 'VALID',
use_bias=not hasattr(w, 'bn'),
kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
self.act = activations(w.act) if act else tf.identity

def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))

# YOLOv5 activations
if isinstance(w.act, nn.LeakyReLU):
self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
elif isinstance(w.act, nn.Hardswish):
self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
elif isinstance(w.act, (nn.SiLU, SiLU)):
self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
else:
raise Exception(f'no matching TensorFlow activation found for {w.act}')

class TFDWConv(keras.layers.Layer):
# Depthwise convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__()
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."

conv = keras.layers.DepthwiseConv2D(
kernel_size=k,
strides=s,
padding='SAME' if s == 1 else 'VALID',
use_bias=not hasattr(w, 'bn'),
kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
self.act = activations(w.act) if act else tf.identity

def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))
@@ -103,10 +116,8 @@ class TFFocus(keras.layers.Layer):

def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
# inputs = inputs / 255 # normalize 0-255 to 0-1
return self.conv(
tf.concat(
[inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]],
3))
inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
return self.conv(tf.concat(inputs, 3))


class TFBottleneck(keras.layers.Layer):
@@ -439,6 +450,18 @@ class AgnosticNMS(keras.layers.Layer):
return padded_boxes, padded_scores, padded_classes, valid_detections


def activations(act=nn.SiLU):
# Returns TF activation from input PyTorch activation
if isinstance(act, nn.LeakyReLU):
return lambda x: keras.activations.relu(x, alpha=0.1)
elif isinstance(act, nn.Hardswish):
return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
elif isinstance(act, (nn.SiLU, SiLU)):
return lambda x: keras.activations.swish(x)
else:
raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')


def representative_dataset_gen(dataset, ncalib=100):
# Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):

Loading…
Cancel
Save