Browse Source

Update `get_loggers()` (#4854)

* Update `set_logging()`

* Update export.py

* pre-commit fixes

* Update LoadImages

* Update LoadStreams

* Update print_args

* Single LOGGER definition

* yolo.py fix

Co-authored-by: pre-commit <pre-commit@example.com>
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
7b1f7aec46
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 84 additions and 91 deletions
  1. +8
    -9
      detect.py
  2. +31
    -32
      export.py
  3. +2
    -5
      models/tf.py
  4. +1
    -4
      models/yolo.py
  5. +5
    -7
      train.py
  6. +15
    -15
      utils/datasets.py
  7. +11
    -7
      utils/general.py
  8. +11
    -12
      val.py

+ 8
- 9
detect.py View File

@@ -25,8 +25,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
strip_optimizer, xyxy2xywh
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, strip_optimizer, xyxy2xywh, LOGGER
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_sync

@@ -68,7 +67,6 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Initialize
set_logging()
device = select_device(device)
half &= device.type != 'cpu' # half precision only supported on CUDA

@@ -132,7 +130,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
dt, seen = [0.0, 0.0, 0.0], 0
for path, img, im0s, vid_cap in dataset:
for path, img, im0s, vid_cap, s in dataset:
t1 = time_sync()
if onnx:
img = img.astype('float32')
@@ -191,9 +189,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
for i, det in enumerate(pred): # per image
seen += 1
if webcam: # batch_size >= 1
p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
p, im0, frame = path[i], im0s[i].copy(), dataset.count
s += f'{i}: '
else:
p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
@@ -227,7 +226,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

# Print time (inference-only)
print(f'{s}Done. ({t3 - t2:.3f}s)')
LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

# Stream results
im0 = annotator.result()
@@ -256,10 +255,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Print results
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
if save_txt or save_img:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
if update:
strip_optimizer(weights) # update model (to fix SourceChangeWarning)


+ 31
- 32
export.py View File

@@ -42,23 +42,23 @@ from models.experimental import attempt_load
from models.yolo import Detect
from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, print_args, \
set_logging, url2file
from utils.general import check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args, \
url2file, LOGGER
from utils.torch_utils import select_device


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript.pt')

ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')
LOGGER.info(f'{prefix} export failure: {e}')


def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
@@ -67,7 +67,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
check_requirements(('onnx',))
import onnx

print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
@@ -82,7 +82,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print
# LOGGER.info(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
if simplify:
@@ -90,7 +90,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
check_requirements(('onnx-simplifier',))
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=dynamic,
@@ -98,11 +98,11 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
LOGGER.info(f'{prefix} simplifier failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
except Exception as e:
print(f'{prefix} export failure: {e}')
LOGGER.info(f'{prefix} export failure: {e}')


def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
@@ -112,7 +112,7 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
check_requirements(('coremltools',))
import coremltools as ct

print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = file.with_suffix('.mlmodel')

model.train() # CoreML exports should be placed in model.train() mode
@@ -120,9 +120,9 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255.0, bias=[0, 0, 0])])
ct_model.save(f)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')

return ct_model

@@ -137,7 +137,7 @@ def export_saved_model(model, im, file, dynamic,
from tensorflow import keras
from models.tf import TFModel, TFDetect

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(file).replace('.pt', '_saved_model')
batch_size, ch, *imgsz = list(im.shape) # BCHW

@@ -151,9 +151,9 @@ def export_saved_model(model, im, file, dynamic,
keras_model.summary()
keras_model.save(f, save_format='tf')

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')

return keras_model

@@ -164,7 +164,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = file.with_suffix('.pb')

m = tf.function(lambda x: keras_model(x)) # full model
@@ -173,9 +173,9 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
@@ -184,7 +184,7 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
import tensorflow as tf
from models.tf import representative_dataset_gen

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
batch_size, ch, *imgsz = list(im.shape) # BCHW
f = str(file).replace('.pt', '-fp16.tflite')

@@ -204,10 +204,10 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te

tflite_model = converter.convert()
open(f, "wb").write(tflite_model)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
@@ -217,7 +217,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
import re
import tensorflowjs as tfjs

print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path
f_json = f + '/model.json' # *.json path
@@ -240,9 +240,9 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
json)
j.write(subst)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


@torch.no_grad()
@@ -297,7 +297,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'

for _ in range(2):
y = model(im) # dry runs
print(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")

# Exports
if 'torchscript' in include:
@@ -322,9 +322,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
export_tfjs(model, im, file)

# Finish
print(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f'\nVisualize with https://netron.app')
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f'\nVisualize with https://netron.app')


def parse_opt():
@@ -355,7 +355,6 @@ def parse_opt():


def main(opt):
set_logging()
run(**vars(opt))



+ 2
- 5
models/tf.py View File

@@ -31,11 +31,9 @@ from tensorflow import keras
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging
from utils.general import make_divisible, print_args, LOGGER
from utils.activations import SiLU

LOGGER = logging.getLogger(__name__)


class TFBN(keras.layers.Layer):
# TensorFlow BatchNormalization wrapper
@@ -336,7 +334,7 @@ class TFModel:

# Define model
if nc and nc != self.yaml['nc']:
print(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

@@ -457,7 +455,6 @@ def parse_opt():


def main(opt):
set_logging()
run(**vars(opt))



+ 1
- 4
models/yolo.py View File

@@ -20,7 +20,7 @@ if str(ROOT) not in sys.path:
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import check_yaml, make_divisible, print_args, set_logging, check_version
from utils.general import check_version, check_yaml, make_divisible, print_args, LOGGER
from utils.plots import feature_visualization
from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
select_device, time_sync
@@ -30,8 +30,6 @@ try:
except ImportError:
thop = None

LOGGER = logging.getLogger(__name__)


class Detect(nn.Module):
stride = None # strides computed during build
@@ -311,7 +309,6 @@ if __name__ == '__main__':
opt = parser.parse_args()
opt.cfg = check_yaml(opt.cfg) # check YAML
print_args(FILE.stem, opt)
set_logging()
device = select_device(opt.device)

# Create model

+ 5
- 7
train.py View File

@@ -40,7 +40,7 @@ from utils.autobatch import check_train_batch_size
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods
check_file, check_yaml, check_suffix, print_args, print_mutation, one_cycle, colorstr, methods, LOGGER
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
@@ -51,7 +51,6 @@ from utils.metrics import fitness
from utils.loggers import Loggers
from utils.callbacks import Callbacks

LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
@@ -129,7 +128,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
print(f'freezing {k}')
LOGGER.info(f'freezing {k}')
v.requires_grad = False

# Image size
@@ -485,7 +484,6 @@ def parse_opt(known=False):

def main(opt, callbacks=Callbacks()):
# Checks
set_logging(RANK)
if RANK in [-1, 0]:
print_args(FILE.stem, opt)
check_git_status()
@@ -609,9 +607,9 @@ def main(opt, callbacks=Callbacks()):

# Plot results
plot_evolve(evolve_csv)
print(f'Hyperparameter evolution finished\n'
f"Results saved to {colorstr('bold', save_dir)}\n"
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
LOGGER.info(f'Hyperparameter evolution finished\n'
f"Results saved to {colorstr('bold', save_dir)}\n"
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')


def run(**kwargs):

+ 15
- 15
utils/datasets.py View File

@@ -28,7 +28,7 @@ from tqdm import tqdm

from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_dataset, check_requirements, check_yaml, clean_str, segments2boxes, \
xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy
xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy, LOGGER
from utils.torch_utils import torch_distributed_zero_first

# Parameters
@@ -210,14 +210,14 @@ class LoadImages:
ret_val, img0 = self.cap.read()

self.frame += 1
print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

else:
# Read image
self.count += 1
img0 = cv2.imread(path) # BGR
assert img0 is not None, 'Image Not Found ' + path
print(f'image {self.count}/{self.nf} {path}: ', end='')
assert img0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
@@ -226,7 +226,7 @@ class LoadImages:
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)

return path, img, img0, self.cap
return path, img, img0, self.cap, s

def new_video(self, path):
self.frame = 0
@@ -264,7 +264,7 @@ class LoadWebcam: # for inference
# Print
assert ret_val, f'Camera Error {self.pipe}'
img_path = 'webcam.jpg'
print(f'webcam {self.count}: ', end='')
s = f'webcam {self.count}: '

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride)[0]
@@ -273,7 +273,7 @@ class LoadWebcam: # for inference
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)

return img_path, img, img0, None
return img_path, img, img0, None, s

def __len__(self):
return 0
@@ -298,14 +298,14 @@ class LoadStreams:
self.auto = auto
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
print(f'{i + 1}/{n}: {s}... ', end='')
st = f'{i + 1}/{n}: {s}... '
if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
check_requirements(('pafy', 'youtube_dl'))
import pafy
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
cap = cv2.VideoCapture(s)
assert cap.isOpened(), f'Failed to open {s}'
assert cap.isOpened(), f'{st}Failed to open {s}'
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback
@@ -313,15 +313,15 @@ class LoadStreams:

_, self.imgs[i] = cap.read() # guarantee first frame
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
print('') # newline
LOGGER.info('') # newline

# check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')

def update(self, i, cap, stream):
# Read stream `i` frames in daemon thread
@@ -335,7 +335,7 @@ class LoadStreams:
if success:
self.imgs[i] = im
else:
print('WARNING: Video stream unresponsive, please check your IP camera connection.')
LOGGER.warn('WARNING: Video stream unresponsive, please check your IP camera connection.')
self.imgs[i] *= 0
cap.open(stream) # re-open stream if signal was lost
time.sleep(1 / self.fps[i]) # wait time
@@ -361,7 +361,7 @@ class LoadStreams:
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
img = np.ascontiguousarray(img)

return self.sources, img, img0, None
return self.sources, img, img0, None, ''

def __len__(self):
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
@@ -666,7 +666,7 @@ def load_image(self, i):
else: # read image
path = self.img_files[i]
im = cv2.imread(path) # BGR
assert im is not None, 'Image Not Found ' + path
assert im is not None, f'Image Not Found {path}'
h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio
if r != 1: # if sizes are not equal

+ 11
- 7
utils/general.py View File

@@ -42,6 +42,16 @@ FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory


def set_logging(name=None, verbose=True):
# Sets level and returns logger
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARN)
return logging.getLogger(name)


LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)


class Profile(contextlib.ContextDecorator):
# Usage: @Profile() decorator or 'with Profile():' context manager
def __enter__(self):
@@ -87,15 +97,9 @@ def methods(instance):
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def set_logging(rank=-1, verbose=True):
logging.basicConfig(
format="%(message)s",
level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def print_args(name, opt):
# Print argparser arguments
print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))


def init_seeds(seed=0):

+ 11
- 12
val.py View File

@@ -25,9 +25,9 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_img_size, check_requirements, \
check_suffix, check_yaml, box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \
increment_path, colorstr, print_args
from utils.general import box_iou, coco80_to_coco91_class, colorstr, check_dataset, check_img_size, \
check_requirements, check_suffix, check_yaml, increment_path, non_max_suppression, print_args, scale_coords, \
xyxy2xywh, xywh2xyxy, LOGGER
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, time_sync
@@ -242,18 +242,18 @@ def run(data,

# Print results
pf = '%20s' + '%11i' * 2 + '%11.3g' * 4 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

# Print results per class
if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
for i, c in enumerate(ap_class):
print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

# Print speeds
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
if not training:
shape = (batch_size, 3, imgsz, imgsz)
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

# Plots
if plots:
@@ -265,7 +265,7 @@ def run(data,
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
print(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
with open(pred_json, 'w') as f:
json.dump(jdict, f)

@@ -284,13 +284,13 @@ def run(data,
eval.summarize()
map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
except Exception as e:
print(f'pycocotools unable to run: {e}')
LOGGER.info(f'pycocotools unable to run: {e}')

# Return results
model.float() # for training
if not training:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
@@ -327,8 +327,7 @@ def parse_opt():


def main(opt):
set_logging()
check_requirements(exclude=('tensorboard', 'thop'))
check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))

if opt.task in ('train', 'val', 'test'): # run normally
run(**vars(opt))
@@ -346,7 +345,7 @@ def main(opt):
f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to
y = [] # y axis
for i in x: # img-size
print(f'\nRunning {f} point {i}...')
LOGGER.info(f'\nRunning {f} point {i}...')
r, _, t = run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres,
iou_thres=opt.iou_thres, device=opt.device, save_json=opt.save_json, plots=False)
y.append(r + t) # results and times

Loading…
Cancel
Save