Browse Source

New TensorFlow `TFCrossConv()` module (#7827)

* New TensorFlow `TFCrossConv()` module

* Move from experimental to common

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add C3x

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add to C3x to yolo.py

* Add to C3x to tf.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* TFC3x bug fix

* TFC3x bug fix

* TFC3x bug fix

* Add TFDWConv g==c1==c2 check

* Add comment

* Update tf.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
fb7fa5be8b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 64 additions and 28 deletions
  1. +23
    -2
      models/common.py
  2. +0
    -14
      models/experimental.py
  3. +39
    -10
      models/tf.py
  4. +2
    -2
      models/yolo.py

+ 23
- 2
models/common.py View File

def autopad(k, p=None): # kernel, padding def autopad(k, p=None): # kernel, padding
# Pad to 'same' # Pad to 'same'
if p is None: if p is None:
p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p return p




return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))




class CrossConv(nn.Module):
# Cross Convolution Downsample
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, k), (1, s))
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
self.add = shortcut and c1 == c2

def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module): class C3(nn.Module):
# CSP Bottleneck with 3 convolutions # CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
self.cv2 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
# self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))


def forward(self, x): def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))




class C3x(C3):
# C3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e)
self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))


class C3TR(C3): class C3TR(C3):
# C3 module with TransformerBlock() # C3 module with TransformerBlock()
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):

+ 0
- 14
models/experimental.py View File

from utils.downloads import attempt_download from utils.downloads import attempt_download




class CrossConv(nn.Module):
# Cross Convolution Downsample
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, k), (1, s))
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
self.add = shortcut and c1 == c2

def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class Sum(nn.Module): class Sum(nn.Module):
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, n, weight=False): # n: number of inputs def __init__(self, n, weight=False): # n: number of inputs

+ 39
- 10
models/tf.py View File

import torch.nn as nn import torch.nn as nn
from tensorflow import keras from tensorflow import keras


from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.activations import SiLU from utils.activations import SiLU
from utils.general import LOGGER, make_divisible, print_args from utils.general import LOGGER, make_divisible, print_args




class TFPad(keras.layers.Layer): class TFPad(keras.layers.Layer):
# Pad inputs in spatial dimensions 1 and 2
def __init__(self, pad): def __init__(self, pad):
super().__init__() super().__init__()
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
if isinstance(pad, int):
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
else: # tuple/list
self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])


def call(self, inputs): def call(self, inputs):
return tf.pad(inputs, self.pad, mode='constant', constant_values=0) return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
# ch_in, ch_out, weights, kernel, stride, padding, groups # ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__() super().__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding) # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

conv = keras.layers.Conv2D( conv = keras.layers.Conv2D(
filters=c2, filters=c2,
kernel_size=k, kernel_size=k,
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups # ch_in, ch_out, weights, kernel, stride, padding, groups
super().__init__() super().__init__()
assert isinstance(k, int), "Convolution with multiple kernels are not allowed."

assert g == c1 == c2, f'TFDWConv() groups={g} must equal input={c1} and output={c2} channels'
conv = keras.layers.DepthwiseConv2D( conv = keras.layers.DepthwiseConv2D(
kernel_size=k, kernel_size=k,
strides=s, strides=s,
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))




class TFCrossConv(keras.layers.Layer):
# Cross Convolution
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
self.add = shortcut and c1 == c2

def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFConv2d(keras.layers.Layer): class TFConv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D # Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))




class TFC3x(keras.layers.Layer):
# 3 module with cross-convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
self.m = keras.Sequential([
TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])

def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFSPP(keras.layers.Layer): class TFSPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP # Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13), w=None): def __init__(self, c1, c2, k=(5, 9, 13), w=None):
pass pass


n = max(round(n * gd), 1) if n > 1 else n # depth gain n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 c2 = make_divisible(c2 * gw, 8) if c2 != no else c2


args = [c1, c2, *args[1:]] args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
if m in [BottleneckCSP, C3, C3x]:
args.insert(2, n) args.insert(2, n)
n = 1 n = 1
elif m is nn.BatchNorm2d: elif m is nn.BatchNorm2d:

+ 2
- 2
models/yolo.py View File



n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost):
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x):
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
if c2 != no: # if not output if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8) c2 = make_divisible(c2 * gw, 8)


args = [c1, c2, *args[1:]] args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
args.insert(2, n) # number of repeats args.insert(2, n) # number of repeats
n = 1 n = 1
elif m is nn.BatchNorm2d: elif m is nn.BatchNorm2d:

Loading…
Cancel
Save