Browse Source

Add `SPPF()` layer (#4420)

* Add `SPPF()` layer

* Cleanup

* Add credit
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
01cdb7671b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 5 deletions
  1. +19
    -1
      models/common.py
  2. +6
    -4
      models/yolo.py

+ 19
- 1
models/common.py View File

@@ -161,7 +161,7 @@ class C3Ghost(C3):


class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
@@ -176,6 +176,24 @@ class SPP(nn.Module):
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))


class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups

+ 6
- 4
models/yolo.py View File

@@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
pass

n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3, C3TR, C3SPP, C3Ghost]:
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
@@ -279,6 +279,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--profile', action='store_true', help='profile model speed')
opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file
set_logging()
@@ -289,8 +290,9 @@ if __name__ == '__main__':
model.train()

# Profile
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
# y = model(img, profile=True)
if opt.profile:
img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
y = model(img, profile=True)

# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
# from torch.utils.tensorboard import SummaryWriter

Loading…
Cancel
Save