Browse Source

fix `tf` conversion in new v6 models (#5153)

* fix `tf` conversion in new v6 (#5147)

* sort imports

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Yoni Chechik GitHub 3 years ago
parent
commit
34da872ab6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 2 deletions
  1. +18
    -2
      models/tf.py

+ 18
- 2
models/tf.py View File

@@ -28,7 +28,7 @@ import torch
import torch.nn as nn
from tensorflow import keras

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging
@@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class TFSPPF(keras.layers.Layer):
# Spatial pyramid pooling-Fast layer
def __init__(self, c1, c2, k=5, w=None):
super(TFSPPF, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

def call(self, inputs):
x = self.cv1(inputs)
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))


class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(TFDetect, self).__init__()
@@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2


Loading…
Cancel
Save