YOLOv5 AWS Inferentia Inplace compatibility updates (#2953)
* Added flag to enable/disable all inplace and assignment operations * Removed shape print statements * Scope Detect/Model import to avoid circular dependency * PEP8 * create _descale_pred() * replace lost space * replace list with tuple Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
955eea8b96
commit
41f5cc5637
|
|
@ -110,7 +110,9 @@ class Ensemble(nn.ModuleList):
|
||||||
return y, None # inference, train output
|
return y, None # inference, train output
|
||||||
|
|
||||||
|
|
||||||
def attempt_load(weights, map_location=None):
|
def attempt_load(weights, map_location=None, inplace=True):
|
||||||
|
from models.yolo import Detect, Model
|
||||||
|
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
|
|
@ -120,8 +122,8 @@ def attempt_load(weights, map_location=None):
|
||||||
|
|
||||||
# Compatibility updates
|
# Compatibility updates
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
|
||||||
m.inplace = True # pytorch 1.7.0 compatibility
|
m.inplace = inplace # pytorch 1.7.0 compatibility
|
||||||
elif type(m) is Conv:
|
elif type(m) is Conv:
|
||||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class Detect(nn.Module):
|
||||||
stride = None # strides computed during build
|
stride = None # strides computed during build
|
||||||
export = False # onnx export
|
export = False # onnx export
|
||||||
|
|
||||||
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
|
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
||||||
super(Detect, self).__init__()
|
super(Detect, self).__init__()
|
||||||
self.nc = nc # number of classes
|
self.nc = nc # number of classes
|
||||||
self.no = nc + 5 # number of outputs per anchor
|
self.no = nc + 5 # number of outputs per anchor
|
||||||
|
|
@ -37,6 +37,7 @@ class Detect(nn.Module):
|
||||||
self.register_buffer('anchors', a) # shape(nl,na,2)
|
self.register_buffer('anchors', a) # shape(nl,na,2)
|
||||||
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
|
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
|
||||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
||||||
|
self.inplace = inplace # use in-place ops (e.g. slice assignment)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x = x.copy() # for profiling
|
# x = x.copy() # for profiling
|
||||||
|
|
@ -52,8 +53,13 @@ class Detect(nn.Module):
|
||||||
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
||||||
|
|
||||||
y = x[i].sigmoid()
|
y = x[i].sigmoid()
|
||||||
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
if self.inplace:
|
||||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||||
|
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||||
|
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
||||||
|
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||||
|
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||||
|
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
||||||
z.append(y.view(bs, -1, self.no))
|
z.append(y.view(bs, -1, self.no))
|
||||||
|
|
||||||
return x if self.training else (torch.cat(z, 1), x)
|
return x if self.training else (torch.cat(z, 1), x)
|
||||||
|
|
@ -85,12 +91,14 @@ class Model(nn.Module):
|
||||||
self.yaml['anchors'] = round(anchors) # override yaml value
|
self.yaml['anchors'] = round(anchors) # override yaml value
|
||||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
||||||
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
||||||
|
self.inplace = self.yaml.get('inplace', True)
|
||||||
# logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
# logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
||||||
|
|
||||||
# Build strides, anchors
|
# Build strides, anchors
|
||||||
m = self.model[-1] # Detect()
|
m = self.model[-1] # Detect()
|
||||||
if isinstance(m, Detect):
|
if isinstance(m, Detect):
|
||||||
s = 256 # 2x min stride
|
s = 256 # 2x min stride
|
||||||
|
m.inplace = self.inplace
|
||||||
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
|
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
|
||||||
m.anchors /= m.stride.view(-1, 1, 1)
|
m.anchors /= m.stride.view(-1, 1, 1)
|
||||||
check_anchor_order(m)
|
check_anchor_order(m)
|
||||||
|
|
@ -105,24 +113,23 @@ class Model(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, augment=False, profile=False):
|
def forward(self, x, augment=False, profile=False):
|
||||||
if augment:
|
if augment:
|
||||||
img_size = x.shape[-2:] # height, width
|
return self.forward_augment(x) # augmented inference, None
|
||||||
s = [1, 0.83, 0.67] # scales
|
|
||||||
f = [None, 3, None] # flips (2-ud, 3-lr)
|
|
||||||
y = [] # outputs
|
|
||||||
for si, fi in zip(s, f):
|
|
||||||
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
|
||||||
yi = self.forward_once(xi)[0] # forward
|
|
||||||
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
|
||||||
yi[..., :4] /= si # de-scale
|
|
||||||
if fi == 2:
|
|
||||||
yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
|
|
||||||
elif fi == 3:
|
|
||||||
yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
|
|
||||||
y.append(yi)
|
|
||||||
return torch.cat(y, 1), None # augmented inference, train
|
|
||||||
else:
|
else:
|
||||||
return self.forward_once(x, profile) # single-scale inference, train
|
return self.forward_once(x, profile) # single-scale inference, train
|
||||||
|
|
||||||
|
def forward_augment(self, x):
|
||||||
|
img_size = x.shape[-2:] # height, width
|
||||||
|
s = [1, 0.83, 0.67] # scales
|
||||||
|
f = [None, 3, None] # flips (2-ud, 3-lr)
|
||||||
|
y = [] # outputs
|
||||||
|
for si, fi in zip(s, f):
|
||||||
|
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
||||||
|
yi = self.forward_once(xi)[0] # forward
|
||||||
|
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
||||||
|
yi = self._descale_pred(yi, fi, si, img_size)
|
||||||
|
y.append(yi)
|
||||||
|
return torch.cat(y, 1), None # augmented inference, train
|
||||||
|
|
||||||
def forward_once(self, x, profile=False):
|
def forward_once(self, x, profile=False):
|
||||||
y, dt = [], [] # outputs
|
y, dt = [], [] # outputs
|
||||||
for m in self.model:
|
for m in self.model:
|
||||||
|
|
@ -146,6 +153,23 @@ class Model(nn.Module):
|
||||||
logger.info('%.1fms total' % sum(dt))
|
logger.info('%.1fms total' % sum(dt))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _descale_pred(self, p, flips, scale, img_size):
|
||||||
|
# de-scale predictions following augmented inference (inverse operation)
|
||||||
|
if self.inplace:
|
||||||
|
p[..., :4] /= scale # de-scale
|
||||||
|
if flips == 2:
|
||||||
|
p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
|
||||||
|
elif flips == 3:
|
||||||
|
p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
|
||||||
|
else:
|
||||||
|
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
|
||||||
|
if flips == 2:
|
||||||
|
y = img_size[0] - y # de-flip ud
|
||||||
|
elif flips == 3:
|
||||||
|
x = img_size[1] - x # de-flip lr
|
||||||
|
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
||||||
|
return p
|
||||||
|
|
||||||
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
||||||
# https://arxiv.org/abs/1708.02002 section 3.3
|
# https://arxiv.org/abs/1708.02002 section 3.3
|
||||||
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue