Hub models `map_location=device` (#3894)

* Hub models `map_location=device`

* cleanup
This commit is contained in:
Glenn Jocher 2021-07-05 16:20:46 +02:00 committed by GitHub
parent 8930e22cce
commit 6a3ee7cf03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -36,13 +36,15 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
fname = Path(name).with_suffix('.pt') # checkpoint filename fname = Path(name).with_suffix('.pt') # checkpoint filename
try: try:
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
if pretrained and channels == 3 and classes == 80: if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model model = attempt_load(fname, map_location=device) # download/load FP32 model
else: else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model model = Model(cfg, channels, classes) # create model
if pretrained: if pretrained:
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load ckpt = torch.load(attempt_download(fname), map_location=device) # load
msd = model.state_dict() # model state_dict msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
@ -51,7 +53,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute
if autoshape: if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
return model.to(device) return model.to(device)
except Exception as e: except Exception as e:

View File

@ -2,7 +2,6 @@
import datetime import datetime
import logging import logging
import math
import os import os
import platform import platform
import subprocess import subprocess
@ -11,6 +10,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import math
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.distributed as dist import torch.distributed as dist
@ -64,7 +64,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
def select_device(device='', batch_size=None): def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3' # device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu' device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
cpu = device == 'cpu'
if cpu: if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested elif device: # non-cpu device requested