Update C3 module (#1705)
This commit is contained in:
parent
7947c86b57
commit
799724108f
|
|
@ -29,7 +29,7 @@ class Conv(nn.Module):
|
||||||
super(Conv, self).__init__()
|
super(Conv, self).__init__()
|
||||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
||||||
self.bn = nn.BatchNorm2d(c2)
|
self.bn = nn.BatchNorm2d(c2)
|
||||||
self.act = nn.Hardswish() if act else nn.Identity()
|
self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
||||||
|
|
||||||
def forward(self, x):
    """Convolve the input, batch-normalize it, then apply the activation."""
    out = self.conv(x)
    out = self.bn(out)
    return self.act(out)
|
||||||
|
|
@ -70,6 +70,21 @@ class BottleneckCSP(nn.Module):
|
||||||
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
|
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
|
||||||
|
|
||||||
|
|
||||||
|
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions.

    Splits the input into two 1x1-conv branches, runs a stack of ``n``
    Bottleneck modules on one branch, concatenates both, and fuses the
    result with a final 1x1 convolution.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        # c1: input channels, c2: output channels, n: number of bottlenecks,
        # shortcut: enable residual adds inside bottlenecks, g: conv groups,
        # e: hidden-channel expansion ratio
        super(C3, self).__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
        # self.m = nn.Sequential(*[CrossConv(hidden, hidden, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        """Run both branches, concatenate on the channel axis, fuse with cv3."""
        deep = self.m(self.cv1(x))
        skip = self.cv2(x)
        return self.cv3(torch.cat((deep, skip), dim=1))
|
||||||
|
|
||||||
|
|
||||||
class SPP(nn.Module):
|
class SPP(nn.Module):
|
||||||
# Spatial pyramid pooling layer used in YOLOv3-SPP
|
# Spatial pyramid pooling layer used in YOLOv3-SPP
|
||||||
def __init__(self, c1, c2, k=(5, 9, 13)):
|
def __init__(self, c1, c2, k=(5, 9, 13)):
|
||||||
|
|
|
||||||
|
|
@ -22,25 +22,6 @@ class CrossConv(nn.Module):
|
||||||
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||||
|
|
||||||
|
|
||||||
class C3(nn.Module):
    """Cross Convolution CSP block.

    One branch passes the input through a stack of CrossConv modules, the
    other is a plain 1x1 projection; both are concatenated, batch-normalized,
    activated with LeakyReLU, and fused by a final Conv.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        # c1: input channels, c2: output channels, n: number of CrossConvs,
        # shortcut: enable residual adds, g: conv groups, e: expansion ratio
        super(C3, self).__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # applied to cat(cv2, cv3)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.m = nn.Sequential(*(CrossConv(hidden, hidden, 3, 1, g, 1.0, shortcut) for _ in range(n)))

    def forward(self, x):
        """Concatenate the CrossConv branch with the skip branch and fuse."""
        deep = self.cv3(self.m(self.cv1(x)))
        skip = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((deep, skip), dim=1))))
|
|
||||||
|
|
||||||
|
|
||||||
class Sum(nn.Module):
|
class Sum(nn.Module):
|
||||||
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
||||||
def __init__(self, n, weight=False): # n: number of inputs
|
def __init__(self, n, weight=False): # n: number of inputs
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,8 @@ import torch.nn as nn
|
||||||
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
|
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape
|
||||||
from models.experimental import MixConv2d, CrossConv, C3
|
from models.experimental import MixConv2d, CrossConv
|
||||||
from utils.autoanchor import check_anchor_order
|
from utils.autoanchor import check_anchor_order
|
||||||
from utils.general import make_divisible, check_file, set_logging
|
from utils.general import make_divisible, check_file, set_logging
|
||||||
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
|
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue