Browse Source

Cat apriori to autolabels (#1484)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
95fa65339f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 8 deletions
  1. +2
    -1
      detect.py
  2. +7
    -6
      test.py
  3. +11
    -1
      utils/general.py

+ 2
- 1
detect.py View File

vid_writer.write(im0) vid_writer.write(im0)


if save_txt or save_img: if save_txt or save_img:
print('Results saved to %s' % save_dir)
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")


print('Done. (%.3fs)' % (time.time() - t0)) print('Done. (%.3fs)' % (time.time() - t0))



+ 7
- 6
test.py View File

img /= 255.0 # 0 - 255 to 0.0 - 1.0 img /= 255.0 # 0 - 255 to 0.0 - 1.0
targets = targets.to(device) targets = targets.to(device)
nb, _, height, width = img.shape # batch size, channels, height, width nb, _, height, width = img.shape # batch size, channels, height, width
whwh = torch.Tensor([width, height, width, height]).to(device)
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)


# Disable gradients
with torch.no_grad(): with torch.no_grad():
# Run model # Run model
t = time_synchronized() t = time_synchronized()
t0 += time_synchronized() - t t0 += time_synchronized() - t


# Compute loss # Compute loss
if training: # if model has loss hyperparameters
if training:
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls


# Run NMS # Run NMS
t = time_synchronized() t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_txt else [] # for autolabelling
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
t1 += time_synchronized() - t t1 += time_synchronized() - t


# Statistics per image # Statistics per image
tcls_tensor = labels[:, 0] tcls_tensor = labels[:, 0]


# target boxes # target boxes
tbox = xywh2xyxy(labels[:, 1:5]) * whwh
tbox = xywh2xyxy(labels[:, 1:5])
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels


# Per target class # Per target class


# Return results # Return results
if not training: if not training:
print('Results saved to %s' % save_dir)
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")
model.float() # for training model.float() # for training
maps = np.zeros(nc) + map maps = np.zeros(nc) + map
for i, c in enumerate(ap_class): for i, c in enumerate(ap_class):

+ 11
- 1
utils/general.py View File

return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)




def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results """Performs Non-Maximum Suppression (NMS) on inference results


Returns: Returns:
time_limit = 10.0 # seconds to quit after time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS


t = time.time() t = time.time()
output = [torch.zeros(0, 6)] * prediction.shape[0] output = [torch.zeros(0, 6)] * prediction.shape[0]
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence x = x[xc[xi]] # confidence


# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 5), device=x.device)
v[:, :4] = l[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)

# If none remain process next image # If none remain process next image
if not x.shape[0]: if not x.shape[0]:
continue continue

Loading…
Cancel
Save