@@ -311,7 +311,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non | |||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] | |||
# Filter by class | |||
if classes: | |||
if classes is not None: | |||
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] | |||
# Apply finite constraint |