Fix `torch.hub.list('ultralytics/yolov5')` pathlib bug (#3921)
This commit is contained in:
parent
87b094bcbc
commit
411842e058
10
hubconf.py
10
hubconf.py
|
|
@ -4,12 +4,9 @@ Usage:
|
||||||
import torch
|
import torch
|
||||||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
FILE = Path(__file__).absolute()
|
|
||||||
|
|
||||||
|
|
||||||
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
||||||
"""Creates a specified YOLOv5 model
|
"""Creates a specified YOLOv5 model
|
||||||
|
|
@ -26,15 +23,18 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
||||||
Returns:
|
Returns:
|
||||||
YOLOv5 pytorch model
|
YOLOv5 pytorch model
|
||||||
"""
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from models.yolo import Model, attempt_load
|
from models.yolo import Model, attempt_load
|
||||||
from utils.general import check_requirements, set_logging
|
from utils.general import check_requirements, set_logging
|
||||||
from utils.google_utils import attempt_download
|
from utils.google_utils import attempt_download
|
||||||
from utils.torch_utils import select_device
|
from utils.torch_utils import select_device
|
||||||
|
|
||||||
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=('tensorboard', 'thop', 'opencv-python'))
|
file = Path(__file__).absolute()
|
||||||
|
check_requirements(requirements=file.parent / 'requirements.txt', exclude=('tensorboard', 'thop', 'opencv-python'))
|
||||||
set_logging(verbose=verbose)
|
set_logging(verbose=verbose)
|
||||||
|
|
||||||
save_dir = Path('') if str(name).endswith('.pt') else FILE.parent
|
save_dir = Path('') if str(name).endswith('.pt') else file.parent
|
||||||
path = (save_dir / name).with_suffix('.pt') # checkpoint path
|
path = (save_dir / name).with_suffix('.pt') # checkpoint path
|
||||||
try:
|
try:
|
||||||
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
|
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue