* refactor anchors and anchor_grid in Detect Layer * fix CI failures by adding compatibility * fix tf failure * fix different devices errors * Cleanup * fix anchors overwriting issue * better refactoring * Remove self.anchor_grid shape check (redundant with self.grid check) Also PEP8 / 120 line width * Convert _make_grid() from static to dynamic method * Remove anchor_grid.to(device) clone() should already clone to same device as self.anchors * fix different devices error Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>modifyDataloader
@@ -295,6 +295,8 @@ class AutoShape(nn.Module): | |||
m = self.model.model[-1] # Detect() | |||
m.stride = fn(m.stride) | |||
m.grid = list(map(fn, m.grid)) | |||
if isinstance(m.anchor_grid, list): | |||
m.anchor_grid = list(map(fn, m.anchor_grid)) | |||
return self | |||
@torch.no_grad() |
@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True): | |||
for m in model.modules(): | |||
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: | |||
m.inplace = inplace # pytorch 1.7.0 compatibility | |||
if type(m) is Detect: | |||
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility | |||
delattr(m, 'anchor_grid') | |||
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) | |||
elif type(m) is Conv: | |||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility | |||
@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer): | |||
self.na = len(anchors[0]) // 2 # number of anchors | |||
self.grid = [tf.zeros(1)] * self.nl # init grid | |||
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32) | |||
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32), | |||
self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), | |||
[self.nl, 1, -1, 1, 2]) | |||
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] | |||
self.training = False # set to False after building model |
@@ -44,9 +44,8 @@ class Detect(nn.Module): | |||
self.nl = len(anchors) # number of detection layers | |||
self.na = len(anchors[0]) // 2 # number of anchors | |||
self.grid = [torch.zeros(1)] * self.nl # init grid | |||
a = torch.tensor(anchors).float().view(self.nl, -1, 2) | |||
self.register_buffer('anchors', a) # shape(nl,na,2) | |||
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) | |||
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid | |||
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2) | |||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv | |||
self.inplace = inplace # use in-place ops (e.g. slice assignment) | |||
@@ -59,7 +58,7 @@ class Detect(nn.Module): | |||
if not self.training: # inference | |||
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic: | |||
self.grid[i] = self._make_grid(nx, ny).to(x[i].device) | |||
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) | |||
y = x[i].sigmoid() | |||
if self.inplace: | |||
@@ -67,16 +66,19 @@ class Detect(nn.Module): | |||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh | |||
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 | |||
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy | |||
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh | |||
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh | |||
y = torch.cat((xy, wh, y[..., 4:]), -1) | |||
z.append(y.view(bs, -1, self.no)) | |||
return x if self.training else (torch.cat(z, 1), x) | |||
@staticmethod | |||
def _make_grid(nx=20, ny=20): | |||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) | |||
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() | |||
def _make_grid(self, nx=20, ny=20, i=0): | |||
d = self.anchors[i].device | |||
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)]) | |||
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float() | |||
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \ | |||
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float() | |||
return grid, anchor_grid | |||
class Model(nn.Module): | |||
@@ -239,6 +241,8 @@ class Model(nn.Module): | |||
if isinstance(m, Detect): | |||
m.stride = fn(m.stride) | |||
m.grid = list(map(fn, m.grid)) | |||
if isinstance(m.anchor_grid, list): | |||
m.anchor_grid = list(map(fn, m.anchor_grid)) | |||
return self | |||
@@ -15,13 +15,12 @@ from utils.general import colorstr | |||
def check_anchor_order(m): | |||
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary | |||
a = m.anchor_grid.prod(-1).view(-1) # anchor area | |||
a = m.anchors.prod(-1).view(-1) # anchor area | |||
da = a[-1] - a[0] # delta a | |||
ds = m.stride[-1] - m.stride[0] # delta s | |||
if da.sign() != ds.sign(): # same order | |||
print('Reversing anchor order') | |||
m.anchors[:] = m.anchors.flip(0) | |||
m.anchor_grid[:] = m.anchor_grid.flip(0) | |||
def check_anchors(dataset, model, thr=4.0, imgsz=640): | |||
@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): | |||
bpr = (best > 1. / thr).float().mean() # best possible recall | |||
return bpr, aat | |||
anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors | |||
bpr, aat = metric(anchors) | |||
anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors | |||
bpr, aat = metric(anchors.cpu().view(-1, 2)) | |||
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') | |||
if bpr < 0.98: # threshold to recompute | |||
print('. Attempting to improve anchors, please wait...') | |||
na = m.anchor_grid.numel() // 2 # number of anchors | |||
na = m.anchors.numel() // 2 # number of anchors | |||
try: | |||
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) | |||
except Exception as e: | |||
@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): | |||
new_bpr = metric(anchors)[0] | |||
if new_bpr > bpr: # replace anchors | |||
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) | |||
m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference | |||
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss | |||
check_anchor_order(m) | |||
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') |