Browse Source

Add colorstr() (#1887)

* Add colorful()

* update

* newline fix

* add git description

* --always

* update loss scaling

* update loss scaling 2

* rename to colorstr()
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
6ab589583c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 22 deletions
  1. +3
    -2
      train.py
  2. +15
    -12
      utils/autoanchor.py
  3. +27
    -1
      utils/general.py
  4. +2
    -4
      utils/loss.py
  5. +13
    -3
      utils/torch_utils.py

+ 3
- 2
train.py View File

check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)


# Model parameters # Model parameters
hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count
hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers
hyp['box'] *= 3. / nl # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
model.nc = nc # attach number of classes to model model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)

+ 15
- 12
utils/autoanchor.py View File

from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
from tqdm import tqdm from tqdm import tqdm


from utils.general import colorstr



def check_anchor_order(m): def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary


def check_anchors(dataset, model, thr=4.0, imgsz=640): def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary # Check anchor fit to data, recompute if necessary
print('\nAnalyzing anchors... ', end='')
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
print(f'\n{prefix}Analyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
return bpr, aat return bpr, aat


bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2)) bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...') print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors na = m.anchor_grid.numel() // 2 # number of anchors
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m) check_anchor_order(m)
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else: else:
print('Original anchors better than new anchors. Proceeding with original anchors.')
print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline print('') # newline




from utils.autoanchor import *; _ = kmean_anchors() from utils.autoanchor import *; _ = kmean_anchors()
""" """
thr = 1. / thr thr = 1. / thr
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '


def metric(k, wh): # compute metrics def metric(k, wh): # compute metrics
r = wh[:, None] / k[None] r = wh[:, None] / k[None]
k = k[np.argsort(k.prod(1))] # sort small to large k = k[np.argsort(k.prod(1))] # sort small to large
x, best = metric(k, wh0) x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
(n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
for i, x in enumerate(k): for i, x in enumerate(k):
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
return k return k
# Filter # Filter
i = (wh0 < 3.0).any(1).sum() i = (wh0 < 3.0).any(1).sum()
if i: if i:
print('WARNING: Extremely small objects found. '
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1


# Kmeans calculation # Kmeans calculation
print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
s = wh.std(0) # sigmas for whitening s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
k *= s k *= s
# Evolve # Evolve
npr = np.random npr = np.random
f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
for _ in pbar: for _ in pbar:
v = np.ones(sh) v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates) while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
fg = anchor_fitness(kg) fg = anchor_fitness(kg)
if fg > f: if fg > f:
f, k = fg, kg.copy() f, k = fg, kg.copy()
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
if verbose: if verbose:
print_results(k) print_results(k)



+ 27
- 1
utils/general.py View File



def check_git_status(): def check_git_status():
# Suggest 'git pull' if repo is out of date # Suggest 'git pull' if repo is out of date
if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s: if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1




def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*prefix, str = input # color arguments, string
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'undelrine': '\033[4m'}

return ''.join(colors[x] for x in prefix) + str + colors['end']


def labels_to_class_weights(labels, nc=80): def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded if labels[0] is None: # no labels loaded

+ 2
- 4
utils/loss.py View File



# Losses # Losses
nt = 0 # number of targets nt = 0 # number of targets
no = len(p) # number of outputs
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7 balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7
for i, pi in enumerate(p): # layer index, layer predictions for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx b, a, gj, gi = indices[i] # image, anchor, gridy, gridx


lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss


s = 3 / no # output count scaling
lbox *= h['box'] * s
lbox *= h['box']
lobj *= h['obj'] lobj *= h['obj']
lcls *= h['cls'] * s
lcls *= h['cls']
bs = tobj.shape[0] # batch size bs = tobj.shape[0] # batch size


loss = lbox + lobj + lcls loss = lbox + lobj + lcls

+ 13
- 3
utils/torch_utils.py View File

import logging import logging
import math import math
import os import os
import subprocess
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path


import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
cudnn.benchmark, cudnn.deterministic = True, False cudnn.benchmark, cudnn.deterministic = True, False




def git_describe():
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
if Path('.git').exists():
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
else:
return ''


def select_device(device='', batch_size=None): def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3' # device = 'cpu' or '0' or '0,1,2,3'
s = f'Using torch {torch.__version__} ' # string
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu' cpu = device.lower() == 'cpu'
if cpu: if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
p = torch.cuda.get_device_properties(i) p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
else: else:
s += 'CPU'
s += 'CPU\n'


logger.info(f'{s}\n') # skip a line
logger.info(s) # skip a line
return torch.device('cuda:0' if cuda else 'cpu') return torch.device('cuda:0' if cuda else 'cpu')





Loading…
Cancel
Save