Update hubconf.py (#1210)
This commit is contained in:
parent
d3dad42256
commit
7f1640695b
20
hubconf.py
20
hubconf.py
|
|
@ -11,8 +11,11 @@ import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.yolo import Model
|
from models.yolo import Model
|
||||||
|
from utils.general import set_logging
|
||||||
from utils.google_utils import attempt_download
|
from utils.google_utils import attempt_download
|
||||||
|
|
||||||
|
set_logging()
|
||||||
|
|
||||||
|
|
||||||
def create(name, pretrained, channels, classes):
|
def create(name, pretrained, channels, classes):
|
||||||
"""Creates a specified YOLOv5 model
|
"""Creates a specified YOLOv5 model
|
||||||
|
|
@ -26,16 +29,19 @@ def create(name, pretrained, channels, classes):
|
||||||
Returns:
|
Returns:
|
||||||
pytorch model
|
pytorch model
|
||||||
"""
|
"""
|
||||||
config = os.path.join(os.path.dirname(__file__), 'models', '%s.yaml' % name) # model.yaml path
|
config = os.path.join(os.path.dirname(__file__), 'models', f'{name}.yaml') # model.yaml path
|
||||||
try:
|
try:
|
||||||
model = Model(config, channels, classes)
|
model = Model(config, channels, classes)
|
||||||
if pretrained:
|
if pretrained:
|
||||||
ckpt = '%s.pt' % name # checkpoint filename
|
fname = f'{name}.pt' # checkpoint filename
|
||||||
attempt_download(ckpt) # download if not found locally
|
attempt_download(fname) # download if not found locally
|
||||||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
|
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
||||||
|
state_dict = ckpt['model'].float().state_dict() # to FP32
|
||||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
||||||
model.load_state_dict(state_dict, strict=False) # load
|
model.load_state_dict(state_dict, strict=False) # load
|
||||||
# model = model.autoshape() # cv2/PIL/np/torch inference: predictions = model(Image.open('image.jpg'))
|
if len(ckpt['model'].names) == classes:
|
||||||
|
model.names = ckpt['model'].names # set class names attribute
|
||||||
|
# model = model.autoshape() # for autoshaping of PIL/cv2/np inputs and NMS
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -98,3 +104,7 @@ def yolov5x(pretrained=False, channels=3, classes=80):
|
||||||
pytorch model
|
pytorch model
|
||||||
"""
|
"""
|
||||||
return create('yolov5x', pretrained, channels, classes)
|
return create('yolov5x', pretrained, channels, classes)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue