Browse Source

Refactor `Detect()` anchors for ONNX <> OpenCV DNN compatibility (#4833)

* refactor anchors and anchor_grid in Detect Layer

* fix CI failures by adding compatibility

* fix tf failure

* fix different devices errors

* Cleanup

* fix anchors overwriting issue

* better refactoring

* Remove self.anchor_grid shape check (redundant with self.grid check)

Also PEP8 / 120 line width

* Convert _make_grid() from static to dynamic method

* Remove anchor_grid.to(device)

clone() should already clone to same device as self.anchors

* fix different devices error

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Jebastin Nadar GitHub 3 years ago
parent
commit
9d75e42f98
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 16 deletions
  1. +2
    -0
      models/common.py
  2. +4
    -0
      models/experimental.py
  3. +1
    -1
      models/tf.py
  4. +13
    -9
      models/yolo.py
  5. +4
    -6
      utils/autoanchor.py

+ 2
- 0
models/common.py View File

@@ -295,6 +295,8 @@ class AutoShape(nn.Module):
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()

+ 4
- 0
models/experimental.py View File

@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
m.inplace = inplace # pytorch 1.7.0 compatibility
if type(m) is Detect:
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
delattr(m, 'anchor_grid')
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
elif type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility


+ 1
- 1
models/tf.py View File

@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer):
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
[self.nl, 1, -1, 1, 2])
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.training = False # set to False after building model

+ 13
- 9
models/yolo.py View File

@@ -44,9 +44,8 @@ class Detect(nn.Module):
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer('anchors', a) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment)

@@ -59,7 +58,7 @@ class Detect(nn.Module):

if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

y = x[i].sigmoid()
if self.inplace:
@@ -67,16 +66,19 @@ class Detect(nn.Module):
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x)

@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
def _make_grid(self, nx=20, ny=20, i=0):
d = self.anchors[i].device
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
return grid, anchor_grid


class Model(nn.Module):
@@ -239,6 +241,8 @@ class Model(nn.Module):
if isinstance(m, Detect):
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self



+ 4
- 6
utils/autoanchor.py View File

@@ -15,13 +15,12 @@ from utils.general import colorstr

def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area
a = m.anchors.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
print('Reversing anchor order')
m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_anchors(dataset, model, thr=4.0, imgsz=640):
@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
bpr = (best > 1. / thr).float().mean() # best possible recall
return bpr, aat

anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
bpr, aat = metric(anchors)
anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
bpr, aat = metric(anchors.cpu().view(-1, 2))
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors
na = m.anchors.numel() // 2 # number of anchors
try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e:
@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')

Loading…
Cancel
Save