@@ -137,7 +137,8 @@ def detect(save_img=False): | |||
vid_writer.write(im0) | |||
if save_txt or save_img: | |||
print('Results saved to %s' % save_dir) | |||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' | |||
print(f"Results saved to {save_dir}{s}") | |||
print('Done. (%.3fs)' % (time.time() - t0)) | |||
@@ -101,9 +101,8 @@ def test(data, | |||
img /= 255.0 # 0 - 255 to 0.0 - 1.0 | |||
targets = targets.to(device) | |||
nb, _, height, width = img.shape # batch size, channels, height, width | |||
whwh = torch.Tensor([width, height, width, height]).to(device) | |||
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) | |||
# Disable gradients | |||
with torch.no_grad(): | |||
# Run model | |||
t = time_synchronized() | |||
@@ -111,12 +110,13 @@ def test(data, | |||
t0 += time_synchronized() - t | |||
# Compute loss | |||
if training: # if model has loss hyperparameters | |||
if training: | |||
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls | |||
# Run NMS | |||
t = time_synchronized() | |||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres) | |||
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_txt else [] # for autolabelling | |||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb) | |||
t1 += time_synchronized() - t | |||
# Statistics per image | |||
@@ -174,7 +174,7 @@ def test(data, | |||
tcls_tensor = labels[:, 0] | |||
# target boxes | |||
tbox = xywh2xyxy(labels[:, 1:5]) * whwh | |||
tbox = xywh2xyxy(labels[:, 1:5]) | |||
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels | |||
# Per target class | |||
@@ -264,7 +264,8 @@ def test(data, | |||
# Return results | |||
if not training: | |||
print('Results saved to %s' % save_dir) | |||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' | |||
print(f"Results saved to {save_dir}{s}") | |||
model.float() # for training | |||
maps = np.zeros(nc) + map | |||
for i, c in enumerate(ap_class): |
@@ -263,7 +263,7 @@ def wh_iou(wh1, wh2): | |||
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) | |||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False): | |||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False, labels=()): | |||
"""Performs Non-Maximum Suppression (NMS) on inference results | |||
Returns: | |||
@@ -279,6 +279,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, | |||
time_limit = 10.0 # seconds to quit after | |||
redundant = True # require redundant detections | |||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) | |||
merge = False # use merge-NMS | |||
t = time.time() | |||
output = [torch.zeros(0, 6)] * prediction.shape[0] | |||
@@ -287,6 +288,15 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, | |||
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height | |||
x = x[xc[xi]] # confidence | |||
# Cat apriori labels if autolabelling | |||
if labels and len(labels[xi]): | |||
l = labels[xi] | |||
v = torch.zeros((len(l), nc + 5), device=x.device) | |||
v[:, :4] = l[:, 1:5] # box | |||
v[:, 4] = 1.0 # conf | |||
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls | |||
x = torch.cat((x, v), 0) | |||
# If none remain process next image | |||
if not x.shape[0]: | |||
continue |