* Update `set_logging()` * Update export.py * pre-commit fixes * Update LoadImages * Update LoadStreams * Update print_args * Single LOGGER definition * yolo.py fix Co-authored-by: pre-commit <pre-commit@example.com>modifyDataloader
@@ -25,8 +25,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative | |||
from models.experimental import attempt_load | |||
from utils.datasets import LoadImages, LoadStreams | |||
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \ | |||
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \ | |||
strip_optimizer, xyxy2xywh | |||
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, strip_optimizer, xyxy2xywh, LOGGER | |||
from utils.plots import Annotator, colors | |||
from utils.torch_utils import load_classifier, select_device, time_sync | |||
@@ -68,7 +67,6 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir | |||
# Initialize | |||
set_logging() | |||
device = select_device(device) | |||
half &= device.type != 'cpu' # half precision only supported on CUDA | |||
@@ -132,7 +130,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
if pt and device.type != 'cpu': | |||
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once | |||
dt, seen = [0.0, 0.0, 0.0], 0 | |||
for path, img, im0s, vid_cap in dataset: | |||
for path, img, im0s, vid_cap, s in dataset: | |||
t1 = time_sync() | |||
if onnx: | |||
img = img.astype('float32') | |||
@@ -191,9 +189,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
for i, det in enumerate(pred): # per image | |||
seen += 1 | |||
if webcam: # batch_size >= 1 | |||
p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count | |||
p, im0, frame = path[i], im0s[i].copy(), dataset.count | |||
s += f'{i}: ' | |||
else: | |||
p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0) | |||
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0) | |||
p = Path(p) # to Path | |||
save_path = str(save_dir / p.name) # img.jpg | |||
@@ -227,7 +226,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) | |||
# Print time (inference-only) | |||
print(f'{s}Done. ({t3 - t2:.3f}s)') | |||
LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)') | |||
# Stream results | |||
im0 = annotator.result() | |||
@@ -256,10 +255,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
# Print results | |||
t = tuple(x / seen * 1E3 for x in dt) # speeds per image | |||
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) | |||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) | |||
if save_txt or save_img: | |||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' | |||
print(f"Results saved to {colorstr('bold', save_dir)}{s}") | |||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") | |||
if update: | |||
strip_optimizer(weights) # update model (to fix SourceChangeWarning) | |||
@@ -42,23 +42,23 @@ from models.experimental import attempt_load | |||
from models.yolo import Detect | |||
from utils.activations import SiLU | |||
from utils.datasets import LoadImages | |||
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, print_args, \ | |||
set_logging, url2file | |||
from utils.general import check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args, \ | |||
url2file, LOGGER | |||
from utils.torch_utils import select_device | |||
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): | |||
# YOLOv5 TorchScript model export | |||
try: | |||
print(f'\n{prefix} starting export with torch {torch.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') | |||
f = file.with_suffix('.torchscript.pt') | |||
ts = torch.jit.trace(model, im, strict=False) | |||
(optimize_for_mobile(ts) if optimize else ts).save(f) | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'{prefix} export failure: {e}') | |||
LOGGER.info(f'{prefix} export failure: {e}') | |||
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')): | |||
@@ -67,7 +67,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst | |||
check_requirements(('onnx',)) | |||
import onnx | |||
print(f'\n{prefix} starting export with onnx {onnx.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') | |||
f = file.with_suffix('.onnx') | |||
torch.onnx.export(model, im, f, verbose=False, opset_version=opset, | |||
@@ -82,7 +82,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst | |||
# Checks | |||
model_onnx = onnx.load(f) # load onnx model | |||
onnx.checker.check_model(model_onnx) # check onnx model | |||
# print(onnx.helper.printable_graph(model_onnx.graph)) # print | |||
# LOGGER.info(onnx.helper.printable_graph(model_onnx.graph)) # print | |||
# Simplify | |||
if simplify: | |||
@@ -90,7 +90,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst | |||
check_requirements(('onnx-simplifier',)) | |||
import onnxsim | |||
print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') | |||
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') | |||
model_onnx, check = onnxsim.simplify( | |||
model_onnx, | |||
dynamic_input_shape=dynamic, | |||
@@ -98,11 +98,11 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst | |||
assert check, 'assert check failed' | |||
onnx.save(model_onnx, f) | |||
except Exception as e: | |||
print(f'{prefix} simplifier failure: {e}') | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'") | |||
LOGGER.info(f'{prefix} simplifier failure: {e}') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'") | |||
except Exception as e: | |||
print(f'{prefix} export failure: {e}') | |||
LOGGER.info(f'{prefix} export failure: {e}') | |||
def export_coreml(model, im, file, prefix=colorstr('CoreML:')): | |||
@@ -112,7 +112,7 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')): | |||
check_requirements(('coremltools',)) | |||
import coremltools as ct | |||
print(f'\n{prefix} starting export with coremltools {ct.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') | |||
f = file.with_suffix('.mlmodel') | |||
model.train() # CoreML exports should be placed in model.train() mode | |||
@@ -120,9 +120,9 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')): | |||
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255.0, bias=[0, 0, 0])]) | |||
ct_model.save(f) | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'\n{prefix} export failure: {e}') | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
return ct_model | |||
@@ -137,7 +137,7 @@ def export_saved_model(model, im, file, dynamic, | |||
from tensorflow import keras | |||
from models.tf import TFModel, TFDetect | |||
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
f = str(file).replace('.pt', '_saved_model') | |||
batch_size, ch, *imgsz = list(im.shape) # BCHW | |||
@@ -151,9 +151,9 @@ def export_saved_model(model, im, file, dynamic, | |||
keras_model.summary() | |||
keras_model.save(f, save_format='tf') | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'\n{prefix} export failure: {e}') | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
return keras_model | |||
@@ -164,7 +164,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')): | |||
import tensorflow as tf | |||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 | |||
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
f = file.with_suffix('.pb') | |||
m = tf.function(lambda x: keras_model(x)) # full model | |||
@@ -173,9 +173,9 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')): | |||
frozen_func.graph.as_graph_def() | |||
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'\n{prefix} export failure: {e}') | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')): | |||
@@ -184,7 +184,7 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te | |||
import tensorflow as tf | |||
from models.tf import representative_dataset_gen | |||
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | |||
batch_size, ch, *imgsz = list(im.shape) # BCHW | |||
f = str(file).replace('.pt', '-fp16.tflite') | |||
@@ -204,10 +204,10 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te | |||
tflite_model = converter.convert() | |||
open(f, "wb").write(tflite_model) | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'\n{prefix} export failure: {e}') | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | |||
@@ -217,7 +217,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | |||
import re | |||
import tensorflowjs as tfjs | |||
print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') | |||
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') | |||
f = str(file).replace('.pt', '_web_model') # js dir | |||
f_pb = file.with_suffix('.pb') # *.pb path | |||
f_json = f + '/model.json' # *.json path | |||
@@ -240,9 +240,9 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | |||
json) | |||
j.write(subst) | |||
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
print(f'\n{prefix} export failure: {e}') | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
@torch.no_grad() | |||
@@ -297,7 +297,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
for _ in range(2): | |||
y = model(im) # dry runs | |||
print(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") | |||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") | |||
# Exports | |||
if 'torchscript' in include: | |||
@@ -322,9 +322,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
export_tfjs(model, im, file) | |||
# Finish | |||
print(f'\nExport complete ({time.time() - t:.2f}s)' | |||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" | |||
f'\nVisualize with https://netron.app') | |||
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' | |||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" | |||
f'\nVisualize with https://netron.app') | |||
def parse_opt(): | |||
@@ -355,7 +355,6 @@ def parse_opt(): | |||
def main(opt): | |||
set_logging() | |||
run(**vars(opt)) | |||
@@ -31,11 +31,9 @@ from tensorflow import keras | |||
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad | |||
from models.experimental import CrossConv, MixConv2d, attempt_load | |||
from models.yolo import Detect | |||
from utils.general import make_divisible, print_args, set_logging | |||
from utils.general import make_divisible, print_args, LOGGER | |||
from utils.activations import SiLU | |||
LOGGER = logging.getLogger(__name__) | |||
class TFBN(keras.layers.Layer): | |||
# TensorFlow BatchNormalization wrapper | |||
@@ -336,7 +334,7 @@ class TFModel: | |||
# Define model | |||
if nc and nc != self.yaml['nc']: | |||
print(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}") | |||
LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}") | |||
self.yaml['nc'] = nc # override yaml value | |||
self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz) | |||
@@ -457,7 +455,6 @@ def parse_opt(): | |||
def main(opt): | |||
set_logging() | |||
run(**vars(opt)) | |||
@@ -20,7 +20,7 @@ if str(ROOT) not in sys.path: | |||
from models.common import * | |||
from models.experimental import * | |||
from utils.autoanchor import check_anchor_order | |||
from utils.general import check_yaml, make_divisible, print_args, set_logging, check_version | |||
from utils.general import check_version, check_yaml, make_divisible, print_args, LOGGER | |||
from utils.plots import feature_visualization | |||
from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \ | |||
select_device, time_sync | |||
@@ -30,8 +30,6 @@ try: | |||
except ImportError: | |||
thop = None | |||
LOGGER = logging.getLogger(__name__) | |||
class Detect(nn.Module): | |||
stride = None # strides computed during build | |||
@@ -311,7 +309,6 @@ if __name__ == '__main__': | |||
opt = parser.parse_args() | |||
opt.cfg = check_yaml(opt.cfg) # check YAML | |||
print_args(FILE.stem, opt) | |||
set_logging() | |||
device = select_device(opt.device) | |||
# Create model |
@@ -40,7 +40,7 @@ from utils.autobatch import check_train_batch_size | |||
from utils.datasets import create_dataloader | |||
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ | |||
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \ | |||
check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods | |||
check_file, check_yaml, check_suffix, print_args, print_mutation, one_cycle, colorstr, methods, LOGGER | |||
from utils.downloads import attempt_download | |||
from utils.loss import ComputeLoss | |||
from utils.plots import plot_labels, plot_evolve | |||
@@ -51,7 +51,6 @@ from utils.metrics import fitness | |||
from utils.loggers import Loggers | |||
from utils.callbacks import Callbacks | |||
LOGGER = logging.getLogger(__name__) | |||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html | |||
RANK = int(os.getenv('RANK', -1)) | |||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) | |||
@@ -129,7 +128,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
for k, v in model.named_parameters(): | |||
v.requires_grad = True # train all layers | |||
if any(x in k for x in freeze): | |||
print(f'freezing {k}') | |||
LOGGER.info(f'freezing {k}') | |||
v.requires_grad = False | |||
# Image size | |||
@@ -485,7 +484,6 @@ def parse_opt(known=False): | |||
def main(opt, callbacks=Callbacks()): | |||
# Checks | |||
set_logging(RANK) | |||
if RANK in [-1, 0]: | |||
print_args(FILE.stem, opt) | |||
check_git_status() | |||
@@ -609,9 +607,9 @@ def main(opt, callbacks=Callbacks()): | |||
# Plot results | |||
plot_evolve(evolve_csv) | |||
print(f'Hyperparameter evolution finished\n' | |||
f"Results saved to {colorstr('bold', save_dir)}\n" | |||
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}') | |||
LOGGER.info(f'Hyperparameter evolution finished\n' | |||
f"Results saved to {colorstr('bold', save_dir)}\n" | |||
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}') | |||
def run(**kwargs): |
@@ -28,7 +28,7 @@ from tqdm import tqdm | |||
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective | |||
from utils.general import check_dataset, check_requirements, check_yaml, clean_str, segments2boxes, \ | |||
xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy | |||
xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy, LOGGER | |||
from utils.torch_utils import torch_distributed_zero_first | |||
# Parameters | |||
@@ -210,14 +210,14 @@ class LoadImages: | |||
ret_val, img0 = self.cap.read() | |||
self.frame += 1 | |||
print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='') | |||
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' | |||
else: | |||
# Read image | |||
self.count += 1 | |||
img0 = cv2.imread(path) # BGR | |||
assert img0 is not None, 'Image Not Found ' + path | |||
print(f'image {self.count}/{self.nf} {path}: ', end='') | |||
assert img0 is not None, f'Image Not Found {path}' | |||
s = f'image {self.count}/{self.nf} {path}: ' | |||
# Padded resize | |||
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0] | |||
@@ -226,7 +226,7 @@ class LoadImages: | |||
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB | |||
img = np.ascontiguousarray(img) | |||
return path, img, img0, self.cap | |||
return path, img, img0, self.cap, s | |||
def new_video(self, path): | |||
self.frame = 0 | |||
@@ -264,7 +264,7 @@ class LoadWebcam: # for inference | |||
assert ret_val, f'Camera Error {self.pipe}' | |||
img_path = 'webcam.jpg' | |||
print(f'webcam {self.count}: ', end='') | |||
s = f'webcam {self.count}: ' | |||
# Padded resize | |||
img = letterbox(img0, self.img_size, stride=self.stride)[0] | |||
@@ -273,7 +273,7 @@ class LoadWebcam: # for inference | |||
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB | |||
img = np.ascontiguousarray(img) | |||
return img_path, img, img0, None | |||
return img_path, img, img0, None, s | |||
def __len__(self): | |||
return 0 | |||
@@ -298,14 +298,14 @@ class LoadStreams: | |||
self.auto = auto | |||
for i, s in enumerate(sources): # index, source | |||
# Start thread to read frames from video stream | |||
print(f'{i + 1}/{n}: {s}... ', end='') | |||
st = f'{i + 1}/{n}: {s}... ' | |||
if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video | |||
check_requirements(('pafy', 'youtube_dl')) | |||
import pafy | |||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL | |||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam | |||
cap = cv2.VideoCapture(s) | |||
assert cap.isOpened(), f'Failed to open {s}' | |||
assert cap.isOpened(), f'{st}Failed to open {s}' | |||
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |||
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |||
self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback | |||
@@ -313,15 +313,15 @@ class LoadStreams: | |||
_, self.imgs[i] = cap.read() # guarantee first frame | |||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True) | |||
print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)") | |||
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)") | |||
self.threads[i].start() | |||
print('') # newline | |||
LOGGER.info('') # newline | |||
# check for common shapes | |||
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs]) | |||
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal | |||
if not self.rect: | |||
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') | |||
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.') | |||
def update(self, i, cap, stream): | |||
# Read stream `i` frames in daemon thread | |||
@@ -335,7 +335,7 @@ class LoadStreams: | |||
if success: | |||
self.imgs[i] = im | |||
else: | |||
print('WARNING: Video stream unresponsive, please check your IP camera connection.') | |||
LOGGER.warn('WARNING: Video stream unresponsive, please check your IP camera connection.') | |||
self.imgs[i] *= 0 | |||
cap.open(stream) # re-open stream if signal was lost | |||
time.sleep(1 / self.fps[i]) # wait time | |||
@@ -361,7 +361,7 @@ class LoadStreams: | |||
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW | |||
img = np.ascontiguousarray(img) | |||
return self.sources, img, img0, None | |||
return self.sources, img, img0, None, '' | |||
def __len__(self): | |||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years | |||
@@ -666,7 +666,7 @@ def load_image(self, i): | |||
else: # read image | |||
path = self.img_files[i] | |||
im = cv2.imread(path) # BGR | |||
assert im is not None, 'Image Not Found ' + path | |||
assert im is not None, f'Image Not Found {path}' | |||
h0, w0 = im.shape[:2] # orig hw | |||
r = self.img_size / max(h0, w0) # ratio | |||
if r != 1: # if sizes are not equal |
@@ -42,6 +42,16 @@ FILE = Path(__file__).resolve() | |||
ROOT = FILE.parents[1] # YOLOv5 root directory | |||
def set_logging(name=None, verbose=True): | |||
# Sets level and returns logger | |||
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings | |||
logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARN) | |||
return logging.getLogger(name) | |||
LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.) | |||
class Profile(contextlib.ContextDecorator): | |||
# Usage: @Profile() decorator or 'with Profile():' context manager | |||
def __enter__(self): | |||
@@ -87,15 +97,9 @@ def methods(instance): | |||
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] | |||
def set_logging(rank=-1, verbose=True): | |||
logging.basicConfig( | |||
format="%(message)s", | |||
level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN) | |||
def print_args(name, opt): | |||
# Print argparser arguments | |||
print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) | |||
LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) | |||
def init_seeds(seed=0): |
@@ -25,9 +25,9 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative | |||
from models.experimental import attempt_load | |||
from utils.datasets import create_dataloader | |||
from utils.general import coco80_to_coco91_class, check_dataset, check_img_size, check_requirements, \ | |||
check_suffix, check_yaml, box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \ | |||
increment_path, colorstr, print_args | |||
from utils.general import box_iou, coco80_to_coco91_class, colorstr, check_dataset, check_img_size, \ | |||
check_requirements, check_suffix, check_yaml, increment_path, non_max_suppression, print_args, scale_coords, \ | |||
xyxy2xywh, xywh2xyxy, LOGGER | |||
from utils.metrics import ap_per_class, ConfusionMatrix | |||
from utils.plots import output_to_target, plot_images, plot_val_study | |||
from utils.torch_utils import select_device, time_sync | |||
@@ -242,18 +242,18 @@ def run(data, | |||
# Print results | |||
pf = '%20s' + '%11i' * 2 + '%11.3g' * 4 # print format | |||
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) | |||
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) | |||
# Print results per class | |||
if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats): | |||
for i, c in enumerate(ap_class): | |||
print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) | |||
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) | |||
# Print speeds | |||
t = tuple(x / seen * 1E3 for x in dt) # speeds per image | |||
if not training: | |||
shape = (batch_size, 3, imgsz, imgsz) | |||
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t) | |||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t) | |||
# Plots | |||
if plots: | |||
@@ -265,7 +265,7 @@ def run(data, | |||
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights | |||
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json | |||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json | |||
print(f'\nEvaluating pycocotools mAP... saving {pred_json}...') | |||
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...') | |||
with open(pred_json, 'w') as f: | |||
json.dump(jdict, f) | |||
@@ -284,13 +284,13 @@ def run(data, | |||
eval.summarize() | |||
map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) | |||
except Exception as e: | |||
print(f'pycocotools unable to run: {e}') | |||
LOGGER.info(f'pycocotools unable to run: {e}') | |||
# Return results | |||
model.float() # for training | |||
if not training: | |||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' | |||
print(f"Results saved to {colorstr('bold', save_dir)}{s}") | |||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") | |||
maps = np.zeros(nc) + map | |||
for i, c in enumerate(ap_class): | |||
maps[c] = ap[i] | |||
@@ -327,8 +327,7 @@ def parse_opt(): | |||
def main(opt): | |||
set_logging() | |||
check_requirements(exclude=('tensorboard', 'thop')) | |||
check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop')) | |||
if opt.task in ('train', 'val', 'test'): # run normally | |||
run(**vars(opt)) | |||
@@ -346,7 +345,7 @@ def main(opt): | |||
f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to | |||
y = [] # y axis | |||
for i in x: # img-size | |||
print(f'\nRunning {f} point {i}...') | |||
LOGGER.info(f'\nRunning {f} point {i}...') | |||
r, _, t = run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres, | |||
iou_thres=opt.iou_thres, device=opt.device, save_json=opt.save_json, plots=False) | |||
y.append(r + t) # results and times |