|
- import os
- import pynvml
- pynvml.nvmlInit()
-
-
- # Install the NVML Python bindings with: pip install nvidia-ml-py3
def usegpu(need_gpu_count=1, *, max_used_ratio=0.8, verbose=True):
    """Pick lightly loaded GPUs and expose them via ``CUDA_VISIBLE_DEVICES``.

    Every device index in ``range(pynvml.nvmlDeviceGetCount())`` is inspected.
    A device qualifies when the fraction of its memory currently in use
    (``used / total`` from ``nvmlDeviceGetMemoryInfo``) is strictly below
    ``max_used_ratio``.

    Args:
        need_gpu_count: Number of GPUs wanted. Must not be negative.
        max_used_ratio: Exclusive upper bound on ``used / total`` for a device
            to count as available. Defaults to 0.8.
        verbose: When true (the default), print each device's ``used`` and
            ``total`` memory values, one per line, as they are read.

    Returns:
        The return shape depends on how many devices qualified (kept as-is
        for backward compatibility; callers typically test truthiness):

        * a list with the first ``need_gpu_count`` qualifying indices when
          enough devices qualified (``CUDA_VISIBLE_DEVICES`` is set to them);
        * the *count* (an int) of qualifying devices when some, but fewer
          than requested, qualified (``CUDA_VISIBLE_DEVICES`` is set to all
          of them);
        * ``0`` when no device qualified (the environment is left untouched).

    Raises:
        ValueError: If ``need_gpu_count`` is negative.
    """
    if need_gpu_count < 0:
        # A negative count would pass the ">=" check below and then make the
        # slice drop devices from the END of the list instead of limiting it.
        raise ValueError(
            "need_gpu_count must be >= 0, got {!r}".format(need_gpu_count)
        )

    available = []
    for index in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        if verbose:
            print(meminfo.used)
            print(meminfo.total)
        # Skip a device that reports zero total memory rather than dividing
        # by zero; it cannot be judged "free" in any meaningful way.
        if meminfo.total and meminfo.used / meminfo.total < max_used_ratio:
            available.append(index)

    if len(available) >= need_gpu_count:
        chosen = available[:need_gpu_count]
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, chosen))
        return chosen

    if available:
        # Fewer devices than requested: export what is there and report how
        # many were found (an int, not the list -- see the docstring).
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, available))
        return len(available)

    return 0
-
if __name__ == '__main__':
    # Manual smoke test: ask for two GPUs and report what came back.
    selected = usegpu(need_gpu_count=2)
    print(selected)
    # usegpu returns 0 when no device qualified, so truthiness is enough
    # to tell success from failure here.
    status = "use gpu ok" if selected else "no gpu is valid"
    print(status)
|