|
|
@@ -40,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int): |
|
|
|
|
|
|
|
|
|
|
|
def device_count():
    """Return the number of CUDA devices reported by ``nvidia-smi -L``.

    Safe alternative to ``torch.cuda.device_count()``: the count is obtained
    from the ``nvidia-smi`` command-line tool. Supports Linux and Windows.

    Returns:
        int: Number of listed GPUs, or 0 if ``nvidia-smi`` is missing, exits
        with an error, or its output cannot be read.

    Raises:
        AssertionError: If the current platform is neither Linux nor Windows.
    """
    assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
    try:
        # Run nvidia-smi directly (no shell pipeline) so that check=True reflects
        # nvidia-smi's own exit status and the same code path works on both
        # platforms without relying on `wc` or `find`.
        listing = subprocess.run(['nvidia-smi', '-L'], capture_output=True, check=True).stdout.decode()
        # Count only the top-level "GPU <index>: ..." entries so that blank lines,
        # indented sub-entries and diagnostic messages are not counted as devices.
        return sum(1 for line in listing.splitlines() if line.startswith('GPU '))
    except Exception:
        # Deliberate best-effort: any failure to query the driver means "no usable GPUs".
        return 0