From 799724108fb4cc9cfc29029f13835ea668979f43 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 15 Dec 2020 22:13:08 -0800 Subject: [PATCH] Update C3 module (#1705) --- models/common.py | 17 ++++++++++++++++- models/experimental.py | 19 ------------------- models/yolo.py | 4 ++-- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/models/common.py b/models/common.py index fcd87cb..c3b51a4 100644 --- a/models/common.py +++ b/models/common.py @@ -29,7 +29,7 @@ class Conv(nn.Module): super(Conv, self).__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) - self.act = nn.Hardswish() if act else nn.Identity() + self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) def forward(self, x): return self.act(self.bn(self.conv(x))) @@ -70,6 +70,21 @@ class BottleneckCSP(nn.Module): return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super(C3, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): diff --git a/models/experimental.py b/models/experimental.py index a2908a1..136e86d 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -22,25 +22,6 @@ class CrossConv(nn.Module): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) -class C3(nn.Module): - # Cross Convolution CSP - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion - super(C3, self).__init__() - c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) - self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) - self.cv4 = Conv(2 * c_, c2, 1, 1) - self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) - self.act = nn.LeakyReLU(0.1, inplace=True) - self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) - - def forward(self, x): - y1 = self.cv3(self.m(self.cv1(x))) - y2 = self.cv2(x) - return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) - - class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs diff --git a/models/yolo.py b/models/yolo.py index dacb035..4ad44af 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -11,8 +11,8 @@ import torch.nn as nn sys.path.append('./') # to run '$ python *.py' files in subdirectories logger = logging.getLogger(__name__) -from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape -from models.experimental import MixConv2d, CrossConv, C3 +from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape +from models.experimental import MixConv2d, CrossConv from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \