@@ -154,7 +154,7 @@ if __name__ == '__main__': | |||
with torch.no_grad(): | |||
if opt.update: # update all models (to fix SourceChangeWarning) | |||
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']: | |||
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']: | |||
detect() | |||
strip_optimizer(opt.weights) | |||
else: |
@@ -90,9 +90,9 @@ class Model(nn.Module): | |||
yi = self.forward_once(xi)[0] # forward | |||
# cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save | |||
yi[..., :4] /= si # de-scale | |||
if fi is 2: | |||
if fi == 2: | |||
yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud | |||
elif fi is 3: | |||
elif fi == 3: | |||
yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr | |||
y.append(yi) | |||
return torch.cat(y, 1), None # augmented inference, train | |||
@@ -148,6 +148,7 @@ class Model(nn.Module): | |||
print('Fusing layers... ', end='') | |||
for m in self.model.modules(): | |||
if type(m) is Conv: | |||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability | |||
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv | |||
m.bn = None # remove batchnorm | |||
m.forward = m.fuseforward # update forward |
@@ -148,8 +148,8 @@ def test(data, | |||
# Per target class | |||
for cls in torch.unique(tcls_tensor): | |||
ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices | |||
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices | |||
ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices | |||
pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices | |||
# Search for detections | |||
if pi.shape[0]: | |||
@@ -157,7 +157,7 @@ def test(data, | |||
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices | |||
# Append detections | |||
for j in (ious > iouv[0]).nonzero(): | |||
for j in (ious > iouv[0]).nonzero(as_tuple=False): | |||
d = ti[i[j]] # detected target | |||
if d not in detected: | |||
detected.append(d) |