Browse Source

Add `autobatch` feature for best `batch-size` estimation (#5092)

* Autobatch

* fix mem

* fix mem2

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update train.py

* print result

* Cleanup print result

* swap fix in call

* to 64

* use total

* fix

* fix

* fix

* fix

* fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Cleanup printing

* Update final printout

* Update autobatch.py

* Update autobatch.py

* Update autobatch.py
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
ca19df5f7f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 7 deletions
  1. +11
    -6
      train.py
  2. +56
    -0
      utils/autobatch.py
  3. +1
    -1
      utils/torch_utils.py

+ 11
- 6
train.py View File

@@ -36,6 +36,7 @@ import val # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
@@ -131,6 +132,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
print(f'freezing {k}')
v.requires_grad = False

# Image size
gs = max(int(model.stride.max()), 32) # grid size (max stride)
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple

# Batch size
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
batch_size = check_train_batch_size(model, imgsz)

# Optimizer
nbs = 64 # nominal batch size
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
@@ -190,11 +199,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

del ckpt, csd

# Image sizes
gs = max(int(model.stride.max()), 32) # grid size (max stride)
nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple

# DP mode
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
@@ -242,6 +246,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

# Model parameters
nl = model.model[-1].nl # number of detection layers (to scale hyps)
hyp['box'] *= 3. / nl # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
@@ -440,7 +445,7 @@ def parse_opt(known=False):
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

+ 56
- 0
utils/autobatch.py View File

@@ -0,0 +1,56 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Auto-batch utils
"""

from copy import deepcopy

import numpy as np
import torch
from torch.cuda import amp

from utils.general import colorstr
from utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640):
# Check YOLOv5 training batch size
with amp.autocast():
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size


def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
# Automatically estimate best batch size to use `fraction` of available CUDA memory
# Usage:
# import torch
# from utils.autobatch import autobatch
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
# print(autobatch(model))

prefix = colorstr('autobatch: ')
print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
return batch_size

d = str(device).upper() # 'CUDA:0'
t = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 # (GB)
r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GB)
a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GB)
f = t - (r + a) # free inside reserved
print(f'{prefix}{d} {t:.3g}G total, {r:.3g}G reserved, {a:.3g}G allocated, {f:.3g}G free')

batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
y = profile(img, model, n=3, device=device)
except Exception as e:
print(f'{prefix}{e}')

y = [x[2] for x in y if x] # memory [2]
batch_sizes = batch_sizes[:len(y)]
p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
print(f'{prefix}Using colorstr(batch-size {b}) for {d} {t * fraction:.3g}G/{t:.3g}G ({fraction * 100:.0f}%)')
return b

+ 1
- 1
utils/torch_utils.py View File

@@ -126,7 +126,7 @@ def profile(input, ops, n=10, device=None):
_ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
t[2] = time_sync()
except Exception as e: # no backward method
print(e)
# print(e) # for debug
t[2] = float('nan')
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward

Loading…
Cancel
Save