diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 1628910..cddb173 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -53,7 +53,7 @@ def git_describe(path=Path(__file__).parent): # path must be a directory return '' # not a git repository -def select_device(device='', batch_size=None, newline=True): +def select_device(device='', batch_size=0, newline=True): # device = 'cpu' or '0' or '0,1,2,3' s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0' @@ -68,7 +68,7 @@ def select_device(device='', batch_size=None, newline=True): if cuda: devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 n = len(devices) # device count - if n > 1 and batch_size: # check batch_size is divisible by device_count + if n > 1 and batch_size > 0: # check batch_size is divisible by device_count assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' space = ' ' * (len(s) + 1) for i, d in enumerate(devices):