PyTorch Hub load directly when possible (#2986)
This commit is contained in:
parent
9b91db6d1a
commit
d08575ee5e
42
hubconf.py
42
hubconf.py
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.yolo import Model
|
from models.yolo import Model, attempt_load
|
||||||
from utils.general import check_requirements, set_logging
|
from utils.general import check_requirements, set_logging
|
||||||
from utils.google_utils import attempt_download
|
from utils.google_utils import attempt_download
|
||||||
from utils.torch_utils import select_device
|
from utils.torch_utils import select_device
|
||||||
|
|
@ -26,33 +26,37 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
||||||
pretrained (bool): load pretrained weights into the model
|
pretrained (bool): load pretrained weights into the model
|
||||||
channels (int): number of input channels
|
channels (int): number of input channels
|
||||||
classes (int): number of model classes
|
classes (int): number of model classes
|
||||||
|
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
|
||||||
|
verbose (bool): print all information to screen
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pytorch model
|
YOLOv5 pytorch model
|
||||||
"""
|
"""
|
||||||
|
set_logging(verbose=verbose)
|
||||||
|
fname = f'{name}.pt' # checkpoint filename
|
||||||
try:
|
try:
|
||||||
set_logging(verbose=verbose)
|
if pretrained and channels == 3 and classes == 80:
|
||||||
|
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
|
||||||
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
else:
|
||||||
model = Model(cfg, channels, classes)
|
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
||||||
if pretrained:
|
model = Model(cfg, channels, classes) # create model
|
||||||
fname = f'{name}.pt' # checkpoint filename
|
if pretrained:
|
||||||
attempt_download(fname) # download if not found locally
|
attempt_download(fname) # download if not found locally
|
||||||
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
||||||
msd = model.state_dict() # model state_dict
|
msd = model.state_dict() # model state_dict
|
||||||
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
||||||
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
||||||
model.load_state_dict(csd, strict=False) # load
|
model.load_state_dict(csd, strict=False) # load
|
||||||
if len(ckpt['model'].names) == classes:
|
if len(ckpt['model'].names) == classes:
|
||||||
model.names = ckpt['model'].names # set class names attribute
|
model.names = ckpt['model'].names # set class names attribute
|
||||||
if autoshape:
|
if autoshape:
|
||||||
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
||||||
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
||||||
s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
|
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
|
||||||
raise Exception(s) from e
|
raise Exception(s) from e
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue