* yolotr
* transformer block
* Remove bias in Transformer
* Remove C3T
* Remove a deprecated class
* put the 2nd LayerNorm into the 2nd residual block
* move example model to models/hub, rename to -transformer
* Add module comments and TODOs
* Remove LN in Transformer
* Add comments for Transformer
* Solve the problem of MA with DDP
* cleanup
* cleanup find_unused_parameters
* PEP8 reformat

Co-authored-by: DingYiwei <846414640@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
@@ -43,6 +43,52 @@ class Conv(nn.Module):
        return self.act(self.conv(x))

class TransformerLayer(nn.Module):
    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x  # multi-head self-attention + residual
        x = self.fc2(self.fc1(x)) + x  # feed-forward (two bias-free linear layers) + residual
        return x
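
For orientation, a minimal usage sketch (not part of the diff): TransformerLayer works in nn.MultiheadAttention's (seq_len, batch, embed_dim) layout, and `c` must be divisible by `num_heads`. The `models.common` import path is an assumption about where this class lands.

```python
import torch
from models.common import TransformerLayer  # assumed module path after this PR

layer = TransformerLayer(c=256, num_heads=4)  # 256 % 4 == 0 is required by nn.MultiheadAttention
tokens = torch.randn(400, 2, 256)             # a 20x20 feature map flattened to 400 tokens, batch of 2
out = layer(tokens)                           # same layout: torch.Size([400, 2, 256])
```
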
class TransformerBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2)  # (b, c2, w*h)
        p = p.unsqueeze(0)  # (1, b, c2, w*h)
        p = p.transpose(0, 3)  # (w*h, b, c2, 1)
        p = p.squeeze(3)  # (w*h, b, c2): sequence layout expected by nn.MultiheadAttention
        e = self.linear(p)  # learnable position embedding
        x = p + e
        x = self.tr(x)
        x = x.unsqueeze(3)
        x = x.transpose(0, 3)
        x = x.reshape(b, self.c2, w, h)  # restore the BCHW feature map
        return x
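
A quick sketch of how the block is driven (illustrative, not from the PR): TransformerBlock consumes and returns an ordinary BCHW feature map, so it can be dropped in wherever a convolutional block fits. The import path is again an assumption.

```python
import torch
from models.common import TransformerBlock  # assumed module path after this PR

block = TransformerBlock(c1=256, c2=256, num_heads=4, num_layers=2)
x = torch.randn(2, 256, 20, 20)              # (batch, channels, height, width)
y = block(x)                                 # self-attention over the 400 spatial positions
print(y.shape)                               # torch.Size([2, 256, 20, 20])
```
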
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion

@@ -90,6 +136,14 @@ class C3(nn.Module):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class C3TR(C3):
    # C3 module with TransformerBlock()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = TransformerBlock(c_, c_, 4, n)  # replace the Bottleneck stack with n TransformerLayers (4 heads)

class SPP(nn.Module):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13)):
@@ -0,0 +1,48 @@
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 9, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, C3TR, [1024, False]],  # 9 <-------- C3TR() Transformer module
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
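
As a sanity check, the new config can be built like any other YOLOv5 yaml through the Model class in models/yolo.py (sketch only; the file path follows the "move example model to models/hub, rename to -transformer" commit above, and nc=80 matches the config).

```python
import torch
from models.yolo import Model  # YOLOv5 model builder

model = Model('models/hub/yolov5s-transformer.yaml', ch=3, nc=80)
_ = model(torch.zeros(1, 3, 640, 640))  # forward pass; layer 9 is the C3TR transformer stage
```
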
@@ -215,13 +215,13 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
-                C3]:
+                C3, C3TR]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
-           if m in [BottleneckCSP, C3]:
+           if m in [BottleneckCSP, C3, C3TR]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
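
To make the effect of this change concrete, a small worked example (illustrative) of how parse_model scales the `[-1, 3, C3TR, [1024, False]]` entry from the yaml above under the yolov5s multiples; make_divisible is the helper parse_model already uses in the hunk.

```python
from utils.general import make_divisible  # helper used by parse_model (assumed import path)

gd, gw = 0.33, 0.50                # depth_multiple, width_multiple from the yaml
c1 = 512                           # output channels of the preceding SPP layer at gw=0.50
c2 = make_divisible(1024 * gw, 8)  # 1024 -> 512
n = max(round(3 * gd), 1)          # 3 repeats -> 1
args = [c1, c2, n, False]          # parse_model builds C3TR(512, 512, 1, shortcut=False)
```
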
@@ -218,7 +218,9 @@ def train(hyp, opt, device, tb_writer=None):
    # DDP mode
    if cuda and rank != -1:
-       model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
+       model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
+                   # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
+                   find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
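
For context, the intent of this change isolated into a small helper (a sketch, not part of the diff, with a hypothetical function name): find_unused_parameters=True is only requested when the model actually contains nn.MultiheadAttention modules, so transformer-free models keep DDP's cheaper default.

```python
import torch.nn as nn

def needs_find_unused_parameters(model: nn.Module) -> bool:
    # nn.MultiheadAttention is incompatible with DDP unless unused parameters are tolerated,
    # see https://github.com/pytorch/pytorch/issues/26698
    return any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules())
```
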