PyTorch Hub load directly when possible (#2986)
This commit is contained in:
parent
9b91db6d1a
commit
d08575ee5e
20
hubconf.py
20
hubconf.py
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
|
||||
from models.yolo import Model
|
||||
from models.yolo import Model, attempt_load
|
||||
from utils.general import check_requirements, set_logging
|
||||
from utils.google_utils import attempt_download
|
||||
from utils.torch_utils import select_device
|
||||
|
|
@ -26,17 +26,21 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
|||
pretrained (bool): load pretrained weights into the model
|
||||
channels (int): number of input channels
|
||||
classes (int): number of model classes
|
||||
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
|
||||
verbose (bool): print all information to screen
|
||||
|
||||
Returns:
|
||||
pytorch model
|
||||
YOLOv5 pytorch model
|
||||
"""
|
||||
try:
|
||||
set_logging(verbose=verbose)
|
||||
|
||||
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
||||
model = Model(cfg, channels, classes)
|
||||
if pretrained:
|
||||
fname = f'{name}.pt' # checkpoint filename
|
||||
try:
|
||||
if pretrained and channels == 3 and classes == 80:
|
||||
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
|
||||
else:
|
||||
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
||||
model = Model(cfg, channels, classes) # create model
|
||||
if pretrained:
|
||||
attempt_download(fname) # download if not found locally
|
||||
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
||||
msd = model.state_dict() # model state_dict
|
||||
|
|
@ -52,7 +56,7 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
|||
|
||||
except Exception as e:
|
||||
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
||||
s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
|
||||
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
|
||||
raise Exception(s) from e
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue