@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):
 
     cuda = not cpu and torch.cuda.is_available()
     if cuda:
-        n = torch.cuda.device_count()
-        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
+        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+        n = len(devices)  # device count
+        if n > 1 and batch_size:  # check batch_size is divisible by device_count
             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
         space = ' ' * len(s)
-        for i, d in enumerate(device.split(',') if device else range(n)):
+        for i, d in enumerate(devices):
             p = torch.cuda.get_device_properties(i)
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
     else:
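Outside the diff, the added lines can be exercised on their own. The snippet below is a minimal sketch of the new device-parsing and batch-size check; the helper name `check_batch_size` is hypothetical and not part of the PR, it only mirrors the `+` lines above.

    import torch

    def check_batch_size(device='', batch_size=None):
        # Parse the requested devices ('0,1,6,7' -> ['0', '1', '6', '7']),
        # or fall back to every visible CUDA device, then require that any
        # given batch_size splits evenly across them.
        devices = device.split(',') if device else range(torch.cuda.device_count())
        n = len(devices)  # device count
        if n > 1 and batch_size:
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        return list(devices)

    # Example: device='0,1' passes with batch_size=64 (64 % 2 == 0)
    # and raises AssertionError with batch_size=63.

Reusing the parsed `devices` list for both the count and the reporting loop is the point of the change: the old code computed `n` from `torch.cuda.device_count()` but iterated over `device.split(',')`, so the two could disagree when only a subset of GPUs was requested.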