
Tensor initialization on device improvements (#6959)

* Update common.py speed improvements

Eliminate .to() ops where possible for reduced data transfer overhead. Primarily affects warmup and PyTorch Hub inference.

* Updates

* Updates

* Update detect.py

* Update val.py
Glenn Jocher (GitHub), 2 years ago
parent commit 701e1177ac
2 changed files with 4 additions and 4 deletions
  1. models/common.py (+1 -1)
  2. val.py (+3 -3)

models/common.py (+1 -1)

@@ -466,7 +466,7 @@ class DetectMultiBackend(nn.Module):
         # Warmup model by running inference once
         if self.pt or self.jit or self.onnx or self.engine:  # warmup types
             if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
-                im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.fp16 else torch.float)  # input image
+                im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
                 self.forward(im)  # warmup

     @staticmethod
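The common.py change creates the warmup tensor on the target device, in the target dtype, in a single step instead of allocating on the CPU, copying to the GPU, and then casting. A minimal sketch of the two patterns outside of the DetectMultiBackend context (the shape, dtype, and device below are illustrative assumptions, not taken from the diff):

import torch

imgsz = (1, 3, 640, 640)  # illustrative warmup input shape
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Old pattern: CPU allocation, host-to-device copy, then a cast to half precision
im_old = torch.zeros(*imgsz).to(device).type(torch.half)

# New pattern: a single allocation, already on the device and in the target dtype
im_new = torch.zeros(*imgsz, dtype=torch.half, device=device)

assert im_old.shape == im_new.shape and im_old.dtype == im_new.dtype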

val.py (+3 -3)

@@ -87,7 +87,7 @@ def process_batch(detections, labels, iouv):
             matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
             # matches = matches[matches[:, 2].argsort()[::-1]]
             matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
-        matches = torch.Tensor(matches).to(iouv.device)
+        matches = torch.from_numpy(matches).to(iouv.device)
         correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
     return correct
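For context on the process_batch change: torch.Tensor(matches) builds a new float32 tensor by copying the NumPy data, while torch.from_numpy(matches) wraps the existing array without a copy and keeps its dtype, so the following .to(iouv.device) is the only transfer. A rough sketch of the difference, with illustrative array contents and dtype:

import numpy as np
import torch

matches = np.array([[0, 1, 0.7], [2, 3, 0.9]], dtype=np.float64)  # illustrative [label, detection, iou] rows

t_copy = torch.Tensor(matches)      # always float32; the data is copied
t_view = torch.from_numpy(matches)  # float64 here; shares memory with `matches`

print(t_copy.dtype, t_view.dtype)   # torch.float32 torch.float64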

@@ -155,7 +155,7 @@ def run(data,
     cuda = device.type != 'cpu'
     is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
     nc = 1 if single_cls else int(data['nc'])  # number of classes
-    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
+    iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
     niou = iouv.numel()

     # Dataloader
@@ -196,7 +196,7 @@ def run(data,
             loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

         # NMS
-        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
+        targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
         t3 = time_sync()
         out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
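The last hunk uses the lowercase torch.tensor factory, which, unlike the torch.Tensor constructor, accepts a device= argument and infers its dtype from the input, so the pixel-scale vector is created directly on the GPU rather than built on the CPU and copied over. A small sketch with assumed placeholder values for width, height, and targets:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
width, height = 640, 480  # placeholder image size

# Old pattern: float32 CPU tensor, followed by a separate copy to the device
scale_old = torch.Tensor([width, height, width, height]).to(device)

# New pattern: created on the device in one step
scale_new = torch.tensor((width, height, width, height), device=device)

targets = torch.rand(8, 6, device=device)  # placeholder (image, class, x, y, w, h) rows
targets[:, 2:] *= scale_new                # scales normalized xywh to pixels, as in val.py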
