Update C3 module (#1705)

This commit is contained in:
Glenn Jocher 2020-12-15 22:13:08 -08:00 committed by GitHub
parent 7947c86b57
commit 799724108f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 22 deletions

View File

@ -29,7 +29,7 @@ class Conv(nn.Module):
super(Conv, self).__init__() super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2) self.bn = nn.BatchNorm2d(c2)
self.act = nn.Hardswish() if act else nn.Identity() self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x): def forward(self, x):
return self.act(self.bn(self.conv(x))) return self.act(self.bn(self.conv(x)))
@ -70,6 +70,21 @@ class BottleneckCSP(nn.Module):
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
class C3(nn.Module):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(C3, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
# self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
class SPP(nn.Module): class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP # Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)): def __init__(self, c1, c2, k=(5, 9, 13)):

View File

@ -22,25 +22,6 @@ class CrossConv(nn.Module):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class C3(nn.Module):
# Cross Convolution CSP
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(C3, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
def forward(self, x):
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
class Sum(nn.Module): class Sum(nn.Module):
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, n, weight=False): # n: number of inputs def __init__(self, n, weight=False): # n: number of inputs

View File

@ -11,8 +11,8 @@ import torch.nn as nn
sys.path.append('./') # to run '$ python *.py' files in subdirectories sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3 from models.experimental import MixConv2d, CrossConv
from utils.autoanchor import check_anchor_order from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging from utils.general import make_divisible, check_file, set_logging
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \