|
|
@@ -27,8 +27,8 @@ import torch |
|
|
|
import torch.nn as nn |
|
|
|
from tensorflow import keras |
|
|
|
|
|
|
|
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad |
|
|
|
from models.experimental import CrossConv, MixConv2d, attempt_load |
|
|
|
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad |
|
|
|
from models.experimental import MixConv2d, attempt_load |
|
|
|
from models.yolo import Detect |
|
|
|
from utils.activations import SiLU |
|
|
|
from utils.general import LOGGER, make_divisible, print_args |
|
|
@@ -50,10 +50,13 @@ class TFBN(keras.layers.Layer): |
|
|
|
|
|
|
|
|
|
|
|
class TFPad(keras.layers.Layer):
    # Pad inputs in spatial dimensions 1 and 2
    def __init__(self, pad):
        """Build the constant paddings tensor used by ``call``.

        Args:
            pad: Either an int, applied to both sides of dimensions 1 and 2,
                or a tuple/list whose element 0 pads dimension 1 and whose
                element 1 pads dimension 2 (same amount on both sides).
        """
        super().__init__()
        # Build the paddings only inside the matching branch: constructing the
        # int-style layout from a tuple/list would give tf.constant a ragged
        # nested list and fail before the tuple/list branch is reached.
        if isinstance(pad, int):
            self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
        else:  # tuple/list
            self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])

    def call(self, inputs):
        """Zero-pad ``inputs`` on dimensions 1 and 2; dimensions 0 and 3 are left unpadded.

        The paddings tensor has four rows, so ``inputs`` must be rank 4
        (presumably NHWC -- confirm against callers).
        """
        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
|
|
@@ -65,10 +68,8 @@ class TFConv(keras.layers.Layer): |
|
|
|
# ch_in, ch_out, weights, kernel, stride, padding, groups |
|
|
|
super().__init__() |
|
|
|
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" |
|
|
|
assert isinstance(k, int), "Convolution with multiple kernels are not allowed." |
|
|
|
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding) |
|
|
|
# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch |
|
|
|
|
|
|
|
conv = keras.layers.Conv2D( |
|
|
|
filters=c2, |
|
|
|
kernel_size=k, |
|
|
@@ -90,8 +91,7 @@ class TFDWConv(keras.layers.Layer): |
|
|
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): |
|
|
|
# ch_in, ch_out, weights, kernel, stride, padding, groups |
|
|
|
super().__init__() |
|
|
|
assert isinstance(k, int), "Convolution with multiple kernels are not allowed." |
|
|
|
|
|
|
|
assert g == c1 == c2, f'TFDWConv() groups={g} must equal input={c1} and output={c2} channels' |
|
|
|
conv = keras.layers.DepthwiseConv2D( |
|
|
|
kernel_size=k, |
|
|
|
strides=s, |
|
|
@@ -133,6 +133,19 @@ class TFBottleneck(keras.layers.Layer): |
|
|
|
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) |
|
|
|
|
|
|
|
|
|
|
|
class TFCrossConv(keras.layers.Layer):
    # Cross Convolution
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
        """Stack a (1, k) convolution and a (k, 1) convolution, with an optional residual add.

        Args:
            c1: Input channels.
            c2: Output channels.
            k: Kernel length used along one axis by each of the two convolutions.
            s: Stride applied along the same axis as the kernel length.
            g: Groups value forwarded to the second convolution only.
            e: Expansion ratio; the hidden width is ``int(c2 * e)``.
            shortcut: Request a residual add; only honoured when ``c1 == c2``.
            w: Object exposing ``cv1`` and ``cv2`` attributes that are forwarded
                to the two ``TFConv`` layers (presumably the matching PyTorch
                module holding trained weights -- confirm against caller).
        """
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        first_kernel, first_stride = (1, k), (1, s)
        second_kernel, second_stride = (k, 1), (s, 1)
        self.cv1 = TFConv(c1, hidden, first_kernel, first_stride, w=w.cv1)
        self.cv2 = TFConv(hidden, c2, second_kernel, second_stride, g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        """Apply cv1 then cv2, adding ``inputs`` back when the shortcut is active."""
        out = self.cv2(self.cv1(inputs))
        if self.add:
            return inputs + out
        return out
|
|
|
|
|
|
|
|
|
|
|
class TFConv2d(keras.layers.Layer): |
|
|
|
# Substitution for PyTorch nn.Conv2D |
|
|
|
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): |
|
|
@@ -187,6 +200,22 @@ class TFC3(keras.layers.Layer): |
|
|
|
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) |
|
|
|
|
|
|
|
|
|
|
|
class TFC3x(keras.layers.Layer):
    # C3 module with cross-convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        """Build two 1x1 branches, a chain of ``n`` TFCrossConv layers, and a 1x1 fusion conv.

        Args:
            c1: Input channels.
            c2: Output channels.
            n: Number of ``TFCrossConv`` layers chained on the first branch.
            shortcut: Forwarded to every ``TFCrossConv``.
            g: Groups value forwarded to every ``TFCrossConv``.
            e: Expansion ratio; the hidden width is ``int(c2 * e)``.
            w: Object exposing ``cv1``, ``cv2``, ``cv3`` and an indexable ``m``
                with at least ``n`` entries (presumably the matching PyTorch
                module holding trained weights -- confirm against caller).
        """
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, hidden, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * hidden, c2, 1, 1, w=w.cv3)
        cross_layers = [
            TFCrossConv(hidden, hidden, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j])
            for j in range(n)
        ]
        self.m = keras.Sequential(cross_layers)

    def call(self, inputs):
        """Concatenate the cross-conv branch and the plain branch on axis 3, then fuse with cv3."""
        cross_branch = self.m(self.cv1(inputs))
        plain_branch = self.cv2(inputs)
        merged = tf.concat((cross_branch, plain_branch), axis=3)
        return self.cv3(merged)
|
|
|
|
|
|
|
|
|
|
|
class TFSPP(keras.layers.Layer): |
|
|
|
# Spatial pyramid pooling layer used in YOLOv3-SPP |
|
|
|
def __init__(self, c1, c2, k=(5, 9, 13), w=None): |
|
|
@@ -310,12 +339,12 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) |
|
|
|
pass |
|
|
|
|
|
|
|
n = max(round(n * gd), 1) if n > 1 else n # depth gain |
|
|
|
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: |
|
|
|
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]: |
|
|
|
c1, c2 = ch[f], args[0] |
|
|
|
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 |
|
|
|
|
|
|
|
args = [c1, c2, *args[1:]] |
|
|
|
if m in [BottleneckCSP, C3]: |
|
|
|
if m in [BottleneckCSP, C3, C3x]: |
|
|
|
args.insert(2, n) |
|
|
|
n = 1 |
|
|
|
elif m is nn.BatchNorm2d: |