Преглед изворни кода

Add colorstr() (#1887)

* Add colorful()

* update

* newline fix

* add git description

* --always

* update loss scaling

* update loss scaling 2

* rename to colorstr()
5.0
Glenn Jocher GitHub пре 3 година
родитељ
комит
6ab589583c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 измењених фајлова са 60 додато и 22 уклоњено
  1. +3
    -2
      train.py
  2. +15
    -12
      utils/autoanchor.py
  3. +27
    -1
      utils/general.py
  4. +2
    -4
      utils/loss.py
  5. +13
    -3
      utils/torch_utils.py

+ 3
- 2
train.py Прегледај датотеку

@@ -216,8 +216,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

# Model parameters
hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count
hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers
hyp['box'] *= 3. / nl # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)

+ 15
- 12
utils/autoanchor.py Прегледај датотеку

@@ -6,6 +6,8 @@ import yaml
from scipy.cluster.vq import kmeans
from tqdm import tqdm

from utils.general import colorstr


def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
@@ -20,7 +22,8 @@ def check_anchor_order(m):

def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
print('\nAnalyzing anchors... ', end='')
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
print(f'\n{prefix}Analyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
@@ -35,7 +38,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
return bpr, aat

bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors
@@ -46,9 +49,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else:
print('Original anchors better than new anchors. Proceeding with original anchors.')
print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline


@@ -70,6 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
from utils.autoanchor import *; _ = kmean_anchors()
"""
thr = 1. / thr
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '

def metric(k, wh): # compute metrics
r = wh[:, None] / k[None]
@@ -85,9 +89,9 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
k = k[np.argsort(k.prod(1))] # sort small to large
x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
(n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
for i, x in enumerate(k):
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
return k
@@ -107,13 +111,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Filter
i = (wh0 < 3.0).any(1).sum()
if i:
print('WARNING: Extremely small objects found. '
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1

# Kmeans calculation
print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
k *= s
@@ -136,7 +139,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Evolve
npr = np.random
f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
for _ in pbar:
v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
@@ -145,7 +148,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
fg = anchor_fitness(kg)
if fg > f:
f, k = fg, kg.copy()
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
if verbose:
print_results(k)


+ 27
- 1
utils/general.py Прегледај датотеку

@@ -47,7 +47,7 @@ def get_latest_run(search_dir='.'):

def check_git_status():
# Suggest 'git pull' if repo is out of date
if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
@@ -115,6 +115,32 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*prefix, str = input # color arguments, string
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'undelrine': '\033[4m'}

return ''.join(colors[x] for x in prefix) + str + colors['end']


def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded

+ 2
- 4
utils/loss.py Прегледај датотеку

@@ -105,7 +105,6 @@ def compute_loss(p, targets, model): # predictions, targets, model

# Losses
nt = 0 # number of targets
no = len(p) # number of outputs
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
@@ -138,10 +137,9 @@ def compute_loss(p, targets, model): # predictions, targets, model

lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss

s = 3 / no # output count scaling
lbox *= h['box'] * s
lbox *= h['box']
lobj *= h['obj']
lcls *= h['cls'] * s
lcls *= h['cls']
bs = tobj.shape[0] # batch size

loss = lbox + lobj + lcls

+ 13
- 3
utils/torch_utils.py Прегледај датотеку

@@ -3,9 +3,11 @@
import logging
import math
import os
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
@@ -41,9 +43,17 @@ def init_torch_seeds(seed=0):
cudnn.benchmark, cudnn.deterministic = True, False


def git_describe():
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
if Path('.git').exists():
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
else:
return ''


def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
s = f'Using torch {torch.__version__} ' # string
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu'
if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
@@ -61,9 +71,9 @@ def select_device(device='', batch_size=None):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
else:
s += 'CPU'
s += 'CPU\n'

logger.info(f'{s}\n') # skip a line
logger.info(s) # skip a line
return torch.device('cuda:0' if cuda else 'cpu')



Loading…
Откажи
Сачувај