Преглед изворни кода

PyTorch Hub load directly when possible (#2986)

modifyDataloader
Glenn Jocher GitHub пре 3 година
родитељ
комит
d08575ee5e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 измењених фајлова са 23 додато и 19 уклоњено
  1. +23
    -19
      hubconf.py

+ 23
- 19
hubconf.py Прегледај датотеку



import torch import torch


from models.yolo import Model
from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download from utils.google_utils import attempt_download
from utils.torch_utils import select_device from utils.torch_utils import select_device
pretrained (bool): load pretrained weights into the model pretrained (bool): load pretrained weights into the model
channels (int): number of input channels channels (int): number of input channels
classes (int): number of model classes classes (int): number of model classes
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
verbose (bool): print all information to screen


Returns: Returns:
pytorch model
YOLOv5 pytorch model
""" """
set_logging(verbose=verbose)
fname = f'{name}.pt' # checkpoint filename
try: try:
set_logging(verbose=verbose)
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes)
if pretrained:
fname = f'{name}.pt' # checkpoint filename
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
return model.to(device) return model.to(device)


except Exception as e: except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36' help_url = 'https://github.com/ultralytics/yolov5/issues/36'
s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
raise Exception(s) from e raise Exception(s) from e





Loading…
Откажи
Сачувај