
Improved `detect.py` timing (#4741)

* Improved detect.py timing

* Eliminate 1 time_sync() call

* Inference-only time

* dash

* #Save section

* Cleanup
Glenn Jocher committed 3 years ago · commit 7af1b4c266
2 changed files with 22 additions and 18 deletions:
  1. detect.py (+14, -10)
  2. val.py (+8, -8)
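Both diffs below lean on time_sync() instead of a bare time.time(). With CUDA, kernel launches return before the GPU finishes, so an unsynchronized wall-clock read undercounts inference time. For reference, the helper in utils/torch_utils.py is essentially the following sketch (reproduced from memory, not part of this diff):

import time

import torch


def time_sync():
    # Wait for all queued CUDA kernels to finish before reading the clock,
    # so the timestamp reflects completed GPU work, not just kernel launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()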

detect.py (+14, -10)



 import argparse
 import sys
-import time
 from pathlib import Path

 import cv2

     # Run inference
     if pt and device.type != 'cpu':
         model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
-    t0 = time.time()
+    dt, seen = [0.0, 0.0, 0.0], 0
     for path, img, im0s, vid_cap in dataset:
+        t1 = time_sync()
         if onnx:
             img = img.astype('float32')
         else:
             img = img / 255.0  # 0 - 255 to 0.0 - 1.0
         if len(img.shape) == 3:
             img = img[None]  # expand for batch dim
+        t2 = time_sync()
+        dt[0] += t2 - t1

         # Inference
-        t1 = time_sync()
         if pt:
             visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
             pred = model(img, augment=augment, visualize=visualize)[0]

             pred[..., 2] *= imgsz[1]  # w
             pred[..., 3] *= imgsz[0]  # h
             pred = torch.tensor(pred)
+        t3 = time_sync()
+        dt[1] += t3 - t2

         # NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        t2 = time_sync()
+        dt[2] += time_sync() - t3

         # Second-stage classifier (optional)
         if classify:
             pred = apply_classifier(pred, modelc, img, im0s)

         # Process predictions
-        for i, det in enumerate(pred):  # detections per image
+        for i, det in enumerate(pred):  # per image
+            seen += 1
             if webcam:  # batch_size >= 1
                 p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
             else:

                     if save_crop:
                         save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

-            # Print time (inference + NMS)
-            print(f'{s}Done. ({t2 - t1:.3f}s)')
+            # Print time (inference-only)
+            print(f'{s}Done. ({t3 - t2:.3f}s)')

             # Stream results
             im0 = annotator.result()

                     vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                 vid_writer[i].write(im0)

+    # Print results
+    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
     if save_txt or save_img:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {colorstr('bold', save_dir)}{s}")

     if update:
         strip_optimizer(weights)  # update model (to fix SourceChangeWarning)

-    print(f'Done. ({time.time() - t0:.3f}s)')


 def parse_opt():
     parser = argparse.ArgumentParser()
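The pattern this diff introduces is a three-bucket accumulator: dt[0] collects pre-process time, dt[1] inference, dt[2] NMS, and seen counts images, so the summary line can report per-stage averages in milliseconds. Below is a minimal self-contained sketch of the same pattern; the preprocess/infer/nms functions and the Identity model are hypothetical stand-ins for the real pipeline stages:

import time

import torch


def time_sync():
    # Same idea as the time_sync() sketch above: synchronize, then read the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def preprocess(img):
    return img / 255.0  # stand-in for letterboxing, dtype conversion, etc.


def infer(model, img):
    with torch.no_grad():
        return model(img)


def nms(pred):
    return pred  # stand-in for non_max_suppression


model = torch.nn.Identity()  # hypothetical model
dt, seen = [0.0, 0.0, 0.0], 0  # per-stage accumulators and image count
for _ in range(8):  # stand-in for the dataset loop
    img = torch.rand(1, 3, 640, 640)
    t1 = time_sync()
    img = preprocess(img)
    t2 = time_sync()
    dt[0] += t2 - t1  # pre-process
    pred = infer(model, img)
    t3 = time_sync()
    dt[1] += t3 - t2  # inference
    pred = nms(pred)
    dt[2] += time_sync() - t3  # NMS
    seen += 1

t = tuple(x / seen * 1E3 for x in dt)  # average ms per image, per stage
print('Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image' % t)

Note that only two time_sync() calls bracket the pre-process and inference stages (t2 ends one and starts the other), which is the "Eliminate 1 time_sync() call" item in the commit message.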

val.py (+8, -8)

     names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
     s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
-    p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
+    dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
     for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
-        t_ = time_sync()
+        t1 = time_sync()
         img = img.to(device, non_blocking=True)
         img = img.half() if half else img.float()  # uint8 to fp16/32
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
         targets = targets.to(device)
         nb, _, height, width = img.shape  # batch size, channels, height, width
-        t = time_sync()
-        t0 += t - t_
+        t2 = time_sync()
+        dt[0] += t2 - t1

         # Run model
         out, train_out = model(img, augment=augment)  # inference and training outputs
-        t1 += time_sync() - t
+        dt[1] += time_sync() - t2

         # Compute loss
         if compute_loss:

         # Run NMS
         targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
-        t = time_sync()
+        t3 = time_sync()
         out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
-        t2 += time_sync() - t
+        dt[2] += time_sync() - t3

         # Statistics per image
         for si, pred in enumerate(out):

             print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

     # Print speeds
-    t = tuple(x / seen * 1E3 for x in (t0, t1, t2))  # speeds per image
+    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
     if not training:
         shape = (batch_size, 3, imgsz, imgsz)
         print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
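One subtlety in the final Speed line, in both files: it mixes an f-string with %-formatting. The f-string is evaluated first, interpolating {shape} while leaving the %.1f specifiers intact, and the trailing % t then fills those from the averaged-speed tuple. A small standalone illustration, with made-up numbers:

shape = (32, 3, 640, 640)
t = (1.2, 7.8, 0.9)  # made-up per-image times in ms, for illustration only
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
# Speed: 1.2ms pre-process, 7.8ms inference, 0.9ms NMS per image at shape (32, 3, 640, 640)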
