PyTorch Hub load directly when possible (#2986)

This commit is contained in:
Glenn Jocher 2021-04-30 14:59:51 +02:00 committed by GitHub
parent 9b91db6d1a
commit d08575ee5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 19 deletions

View File

@ -9,7 +9,7 @@ from pathlib import Path
import torch
from models.yolo import Model
from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device
@ -26,17 +26,21 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
pretrained (bool): load pretrained weights into the model
channels (int): number of input channels
classes (int): number of model classes
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
verbose (bool): print all information to screen
Returns:
pytorch model
YOLOv5 pytorch model
"""
try:
set_logging(verbose=verbose)
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes)
if pretrained:
fname = f'{name}.pt' # checkpoint filename
try:
if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
@ -52,7 +56,7 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
raise Exception(s) from e