* PyTorch Hub models default to CUDA:0 if available
* Device-as-string bug fix
# hubconf.py
from models.yolo import Model
from utils.general import set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device

dependencies = ['torch', 'yaml']
set_logging()

        model.names = ckpt['model'].names  # set class names attribute
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
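A brief usage sketch of the Hub entry point after this change (the 'yolov5s' model name and the example image URL are illustrative; the resulting device depends on whether CUDA is visible):

import torch

# Load a pretrained model through PyTorch Hub; with the change above the returned
# model is already on cuda:0 when a GPU is available, otherwise on the CPU.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
print(next(model.parameters()).device)  # cuda:0 on a GPU machine, cpu otherwise

results = model('https://ultralytics.com/images/zidane.jpg')  # autoshape handles URI/PIL/cv2/np inputs
results.print()  # print detection summary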
# utils/datasets.py
        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        if nf == 0:
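The tqdm(None, total=n, initial=n) call above only renders an already-complete bar, so cached scan results are displayed without re-iterating the dataset. A minimal standalone sketch of the same pattern (the counts below are made up for illustration):

from tqdm import tqdm

nf, nm, ne, nc, n = 120, 3, 2, 0, 125  # found, missing, empty, corrupted, total (illustrative numbers)
d = f"Scanning 'train.cache' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=d, total=n, initial=n)  # no iterable: the bar is created at n/n and just prints the summary line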
f"Use 'git pull' to update or 'git clone {url}' to download latest." | f"Use 'git pull' to update or 'git clone {url}' to download latest." | ||||
else: | else: | ||||
s = f'up to date with {url} ✅' | s = f'up to date with {url} ✅' | ||||
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) | |||||
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe | |||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
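The encode/decode guard above exists because legacy Windows consoles may fail to render characters such as ✅; a small sketch of the same idea wrapped in a helper (the helper name is illustrative, not part of the diff):

import platform

def emoji_safe(s):  # illustrative helper; the diff inlines this expression instead
    # Round-trip through ASCII on Windows so unsupported emoji are silently dropped.
    return s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s

print(emoji_safe('up to date with https://github.com/ultralytics/yolov5 ✅'))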
# PyTorch utils

import logging
import math
import os
import platform
import subprocess
import time
from contextlib import contextmanager


def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    else:

        s += 'CPU\n'

    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')
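A hedged sketch of calling the updated select_device from user code (it assumes the yolov5 repo root is importable; the tiny Conv2d stands in for a real model):

import torch
import torch.nn as nn
from utils.torch_utils import select_device  # assumes the yolov5 repo root is on PYTHONPATH

# '' or '0' selects cuda:0 when available, 'cpu' forces CPU, '0,1,2,3' selects multiple GPUs
device = select_device('0' if torch.cuda.is_available() else 'cpu')

model = nn.Conv2d(3, 16, 3).to(device)          # any nn.Module works; a single conv layer stands in here
x = torch.zeros(1, 3, 640, 640, device=device)  # dummy input created on the same device
print(model(x).shape, device)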