|
|
@@ -62,8 +62,7 @@ def select_device(device='', batch_size=0, newline=True):
         assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
             f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
 
-    cuda = not cpu and torch.cuda.is_available()
-    if cuda:
+    if not cpu and torch.cuda.is_available():  # prefer GPU if available
         devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
         n = len(devices)  # device count
         if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
@@ -72,15 +71,18 @@ def select_device(device='', batch_size=0, newline=True):
         for i, d in enumerate(devices):
             p = torch.cuda.get_device_properties(i)
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
-    elif mps:
+        arg = 'cuda:0'
+    elif not cpu and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
         s += 'MPS\n'
-    else:
+        arg = 'mps'
+    else:  # revert to CPU
         s += 'CPU\n'
+        arg = 'cpu'
 
     if not newline:
         s = s.rstrip()
     LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
-    return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
+    return torch.device(arg)
 
 
 def time_sync():
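
For reference, below is a minimal, self-contained sketch of the selection logic this patch arrives at, with the logging, `CUDA_VISIBLE_DEVICES` handling, and batch-size divisibility check stripped out. The function name `pick_device` and the standalone structure are illustrative, not part of the patch; the CUDA → MPS → CPU preference order and the `arg`-based return are taken from the diff above.

```python
import torch

def pick_device(device: str = '') -> torch.device:
    """Illustrative sketch of the post-patch preference order: CUDA, then MPS, then CPU."""
    device = str(device).strip().lower().replace('cuda:', '')  # normalize, e.g. 'cuda:0' -> '0'
    cpu = device == 'cpu'
    if not cpu and torch.cuda.is_available():  # prefer GPU if available
        arg = 'cuda:0'
    elif not cpu and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # then Apple MPS
        arg = 'mps'
    else:  # revert to CPU
        arg = 'cpu'
    return torch.device(arg)

print(pick_device())       # e.g. device(type='cuda', index=0) on a CUDA machine
print(pick_device('cpu'))  # device(type='cpu')
```

Setting `arg` in each branch and returning `torch.device(arg)` once puts the chosen backend in a single place, rather than re-deriving it in the chained conditional the old `return` used.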