* PyTorch Hub models default to CUDA:0 if available
* device as string bug fix
--- a/hubconf.py
+++ b/hubconf.py
@@ -12,6 +12,7 @@ import torch
 from models.yolo import Model
 from utils.general import set_logging
 from utils.google_utils import attempt_download
+from utils.torch_utils import select_device

 dependencies = ['torch', 'yaml']
 set_logging()
@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
                 model.names = ckpt['model'].names  # set class names attribute
         if autoshape:
             model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
-        return model
+        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
+        return model.to(device)

     except Exception as e:
         help_url = 'https://github.com/ultralytics/yolov5/issues/36'
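With this change, models returned by `create()` (and therefore by PyTorch Hub) are moved to `cuda:0` when a GPU is visible instead of staying on the CPU. A minimal usage sketch, assuming the standard `ultralytics/yolov5` Hub entry point and the pretrained `yolov5s` checkpoint:

```python
import torch

# Load a pretrained YOLOv5s model via PyTorch Hub (downloads weights on first call)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# After this change the model lands on cuda:0 when a GPU is available, else on CPU
print(next(model.parameters()).device)  # cuda:0 on a GPU machine, cpu otherwise

# Move it explicitly if a different device is wanted
model = model.to('cpu')
```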
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         # Display cache
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
         if exists:
-            d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
             tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
         assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
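The `tqdm(None, ...)` call above only renders an already-complete bar, so cached scan results are displayed in the same format as a live scan. A small illustration of that pattern, with hypothetical counts:

```python
from tqdm import tqdm

# Hypothetical cache totals: found, missing, empty, corrupted, total
nf, nm, ne, nc, n = 95, 3, 2, 0, 100
d = f"Scanning 'train.cache' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=d, total=n, initial=n)  # bar is created at n/n; nothing is iterated
```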
@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
nc += 1 | |||
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') | |||
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \ | |||
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \ | |||
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted" | |||
if nf == 0: |
--- a/utils/general.py
+++ b/utils/general.py
@@ -79,7 +79,7 @@ def check_git_status():
                 f"Use 'git pull' to update or 'git clone {url}' to download latest."
         else:
             s = f'up to date with {url} ✅'
-        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
+        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     except Exception as e:
         print(e)
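The `encode().decode('ascii', 'ignore')` round trip strips any non-ASCII characters (here the ✅ emoji), avoiding `UnicodeEncodeError` on Windows consoles that default to a legacy code page. A quick illustration of the idiom:

```python
import platform

s = 'up to date with https://github.com/ultralytics/yolov5 ✅'
safe = s.encode().decode('ascii', 'ignore')  # UTF-8 bytes, then drop non-ASCII bytes
print(safe)  # -> 'up to date with https://github.com/ultralytics/yolov5 '
print(safe if platform.system() == 'Windows' else s)  # emoji only where the console can render it
```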
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,8 +1,8 @@
 # PyTorch utils
 import logging
 import math
 import os
+import platform
 import subprocess
 import time
 from contextlib import contextmanager
@@ -53,7 +53,7 @@ def git_describe():
 def select_device(device='', batch_size=None):
     # device = 'cpu' or '0' or '0,1,2,3'
-    s = f'YOLOv5 {git_describe()} torch {torch.__version__} '  # string
+    s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} '  # string
     cpu = device.lower() == 'cpu'
     if cpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
     else:
         s += 'CPU\n'

-    logger.info(s)  # skip a line
+    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     return torch.device('cuda:0' if cuda else 'cpu')
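For reference, `select_device` takes the requested device as a string, logs the hardware summary built above, and returns a `torch.device`. A short usage sketch against the helper as defined in `utils/torch_utils.py` (the multi-GPU behavior noted in the comments is an assumption based on the rest of the function, not shown in this diff):

```python
from utils.torch_utils import select_device

device = select_device('')     # cuda:0 if any GPU is visible, else cpu
device = select_device('cpu')  # force CPU (sets CUDA_VISIBLE_DEVICES=-1)
# device = select_device('0,1', batch_size=16)  # multi-GPU request; asserts CUDA is available
#                                               # and that batch_size divides evenly across GPUs
print(device)                  # e.g. cuda:0 or cpu
```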