Browse Source

Add `@threaded` decorator (#7813)

* Add `@threaded` decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
4a295b1a89
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 17 deletions
  1. +2
    -2
      train.py
  2. +11
    -0
      utils/general.py
  3. +3
    -4
      utils/loggers/__init__.py
  4. +7
    -6
      utils/plots.py
  5. +2
    -5
      val.py

+ 2
- 2
train.py View File

from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
methods, one_cycle, print_args, print_mutation, strip_optimizer)
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss

+ 11
- 0
utils/general.py View File

import re import re
import shutil import shutil
import signal import signal
import threading
import time import time
import urllib import urllib
from datetime import datetime from datetime import datetime
return handler return handler




def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def methods(instance): def methods(instance):
# Get class/instance methods # Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]

+ 3
- 4
utils/loggers/__init__.py View File



import os import os
import warnings import warnings
from threading import Thread


import pkg_resources as pkg import pkg_resources as pkg
import torch import torch
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if ni < 3: if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename f = self.save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
plot_images(imgs, targets, paths, f)
if self.wandb and ni == 10: if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg')) files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]}) self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})


def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch # Callback runs at the end of each fit (train+val) epoch
x = {k: v for k, v in zip(self.keys, vals)} # dict
x = dict(zip(self.keys, vals))
if self.csv: if self.csv:
file = self.save_dir / 'results.csv' file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols n = len(x) + 1 # number of cols
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC') self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')


if self.wandb: if self.wandb:
self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results
self.wandb.log(dict(zip(self.keys[3:10], results)))
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
if not self.opt.evolve: if not self.opt.evolve:

+ 7
- 6
utils/plots.py View File

from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont


from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords, from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness from utils.metrics import fitness


# Settings # Settings
# Ultralytics color palette https://ultralytics.com/ # Ultralytics color palette https://ultralytics.com/
def __init__(self): def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values() # hex = matplotlib.colors.TABLEAU_COLORS.values()
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb('#' + c) for c in hex]
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
self.n = len(self.palette) self.n = len(self.palette)


def __call__(self, i, bgr=False): def __call__(self, i, bgr=False):
if label: if label:
tf = max(self.lw - 1, 1) # font thickness tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box
outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im, cv2.putText(self.im,
return np.array(targets) return np.array(targets)




@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16): def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
# Plot image grid with labels # Plot image grid with labels
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
ax = ax.ravel() ax = ax.ravel()
files = list(save_dir.glob('results*.csv')) files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for fi, f in enumerate(files):
for f in files:
try: try:
data = pd.read_csv(f) data = pd.read_csv(f)
s = [x.strip() for x in data.columns] s = [x.strip() for x in data.columns]

+ 2
- 5
val.py View File

import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread


import numpy as np import numpy as np
import torch import torch


# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred


callbacks.run('on_val_batch_end') callbacks.run('on_val_batch_end')



Loading…
Cancel
Save