|
|
@@ -91,9 +91,10 @@ class TFDWConv(keras.layers.Layer): |
|
|
|
def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None): |
|
|
|
# ch_in, ch_out, weights, kernel, stride, padding, groups |
|
|
|
super().__init__() |
|
|
|
assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels' |
|
|
|
assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels' |
|
|
|
conv = keras.layers.DepthwiseConv2D( |
|
|
|
kernel_size=k, |
|
|
|
depth_multiplier=c2 // c1, |
|
|
|
strides=s, |
|
|
|
padding='SAME' if s == 1 else 'VALID', |
|
|
|
use_bias=not hasattr(w, 'bn'), |