* Update evolution to CSV format * Update * Update * Update * Update * Update * reset args * reset args * reset args * plot_results() fix * Cleanup * Cleanup2modifyDataloader
@@ -8,7 +8,7 @@ coco | |||
storage.googleapis.com | |||
data/samples/* | |||
**/results*.txt | |||
**/results*.csv | |||
*.jpg | |||
# Neural Network weights ----------------------------------------------------------------------------------------------- |
@@ -30,7 +30,6 @@ data/* | |||
!data/images/bus.jpg | |||
!data/*.sh | |||
results*.txt | |||
results*.csv | |||
# Datasets ------------------------------------------------------------------------------------------------------------- |
@@ -37,7 +37,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima | |||
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods | |||
from utils.downloads import attempt_download | |||
from utils.loss import ComputeLoss | |||
from utils.plots import plot_labels, plot_evolution | |||
from utils.plots import plot_labels, plot_evolve | |||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel | |||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | |||
from utils.metrics import fitness | |||
@@ -367,7 +367,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] | |||
if fi > best_fitness: | |||
best_fitness = fi | |||
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi) | |||
log_vals = list(mloss) + list(results) + lr | |||
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi) | |||
# Save model | |||
if (not nosave) or (final_epoch and not evolve): # if save | |||
@@ -464,7 +465,7 @@ def main(opt): | |||
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop']) | |||
# Resume | |||
if opt.resume and not check_wandb_resume(opt): # resume an interrupted run | |||
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run | |||
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path | |||
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' | |||
with open(Path(ckpt).parent.parent / 'opt.yaml') as f: | |||
@@ -474,8 +475,10 @@ def main(opt): | |||
else: | |||
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files | |||
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' | |||
opt.name = 'evolve' if opt.evolve else opt.name | |||
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)) | |||
if opt.evolve: | |||
opt.project = 'runs/evolve' | |||
opt.exist_ok = opt.resume | |||
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) | |||
# DDP mode | |||
device = select_device(opt.device, batch_size=opt.batch_size) | |||
@@ -533,17 +536,17 @@ def main(opt): | |||
hyp = yaml.safe_load(f) # load hyps dict | |||
if 'anchors' not in hyp: # anchors commented in hyp.yaml | |||
hyp['anchors'] = 3 | |||
opt.noval, opt.nosave = True, True # only val/save final epoch | |||
opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch | |||
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices | |||
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here | |||
evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv' | |||
if opt.bucket: | |||
os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .') # download evolve.txt if exists | |||
os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists | |||
for _ in range(opt.evolve): # generations to evolve | |||
if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate | |||
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate | |||
# Select parent(s) | |||
parent = 'single' # parent selection method: 'single' or 'weighted' | |||
x = np.loadtxt('evolve.txt', ndmin=2) | |||
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1) | |||
n = min(5, len(x)) # number of previous results to consider | |||
x = x[np.argsort(-fitness(x))][:n] # top n mutations | |||
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0) | |||
@@ -575,12 +578,13 @@ def main(opt): | |||
results = train(hyp.copy(), opt, device) | |||
# Write mutation results | |||
print_mutation(hyp.copy(), results, yaml_file, opt.bucket) | |||
print_mutation(results, hyp.copy(), save_dir, opt.bucket) | |||
# Plot results | |||
plot_evolution(yaml_file) | |||
print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n' | |||
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}') | |||
plot_evolve(evolve_csv) | |||
print(f'Hyperparameter evolution finished\n' | |||
f"Results saved to {colorstr('bold', save_dir)}" | |||
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}') | |||
def run(**kwargs): |
@@ -615,35 +615,43 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op | |||
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") | |||
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): | |||
# Print mutation results to evolve.txt (for use with train.py --evolve) | |||
a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys | |||
b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values | |||
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3) | |||
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c)) | |||
def print_mutation(results, hyp, save_dir, bucket): | |||
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml' | |||
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', | |||
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] | |||
keys = tuple(x.strip() for x in keys) | |||
vals = results + tuple(hyp.values()) | |||
n = len(keys) | |||
# Download (optional) | |||
if bucket: | |||
url = 'gs://%s/evolve.txt' % bucket | |||
if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0): | |||
os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local | |||
url = f'gs://{bucket}/evolve.csv' | |||
if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0): | |||
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local | |||
# Log to evolve.csv | |||
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header | |||
with open(evolve_csv, 'a') as f: | |||
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n') | |||
with open('evolve.txt', 'a') as f: # append result | |||
f.write(c + b + '\n') | |||
x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows | |||
x = x[np.argsort(-fitness(x))] # sort | |||
np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness | |||
# Print to screen | |||
print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys)) | |||
print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n') | |||
# Save yaml | |||
for i, k in enumerate(hyp.keys()): | |||
hyp[k] = float(x[0, i + 7]) | |||
with open(yaml_file, 'w') as f: | |||
results = tuple(x[0, :7]) | |||
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3) | |||
f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n') | |||
with open(evolve_yaml, 'w') as f: | |||
data = pd.read_csv(evolve_csv) | |||
data = data.rename(columns=lambda x: x.strip()) # strip keys | |||
i = np.argmax(fitness(data.values[:, :7])) # | |||
f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' + | |||
f'# Best generation: {i}\n' + | |||
f'# Last generation: {len(data)}\n' + | |||
f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' + | |||
f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n') | |||
yaml.safe_dump(hyp, f, sort_keys=False) | |||
if bucket: | |||
os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload | |||
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload | |||
def apply_classifier(x, model, img, im0): |
@@ -95,9 +95,8 @@ class Loggers(): | |||
files = sorted(self.save_dir.glob('val*.jpg')) | |||
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]}) | |||
def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi): | |||
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): | |||
# Callback runs at the end of each fit (train+val) epoch | |||
vals = list(mloss) + list(results) + lr | |||
x = {k: v for k, v in zip(self.keys, vals)} # dict | |||
if self.csv: | |||
file = self.save_dir / 'results.csv' | |||
@@ -123,7 +122,7 @@ class Loggers(): | |||
def on_train_end(self, last, best, plots, epoch): | |||
# Callback runs on training end | |||
if plots: | |||
plot_results(dir=self.save_dir) # save results.png | |||
plot_results(file=self.save_dir / 'results.csv') # save results.png | |||
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]] | |||
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter | |||
@@ -325,30 +325,6 @@ def plot_labels(labels, names=(), save_dir=Path('')): | |||
plt.close() | |||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() | |||
# Plot hyperparameter evolution results in evolve.txt | |||
with open(yaml_file) as f: | |||
hyp = yaml.safe_load(f) | |||
x = np.loadtxt('evolve.txt', ndmin=2) | |||
f = fitness(x) | |||
# weights = (f - f.min()) ** 2 # for weighted results | |||
plt.figure(figsize=(10, 12), tight_layout=True) | |||
matplotlib.rc('font', **{'size': 8}) | |||
for i, (k, v) in enumerate(hyp.items()): | |||
y = x[:, i + 7] | |||
# mu = (y * weights).sum() / weights.sum() # best weighted result | |||
mu = y[f.argmax()] # best single result | |||
plt.subplot(6, 5, i + 1) | |||
plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') | |||
plt.plot(mu, f.max(), 'k+', markersize=15) | |||
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters | |||
if i % 5 != 0: | |||
plt.yticks([]) | |||
print('%15s: %.3g' % (k, mu)) | |||
plt.savefig('evolve.png', dpi=200) | |||
print('\nPlot saved as evolve.png') | |||
def profile_idetection(start=0, stop=0, labels=(), save_dir=''): | |||
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection() | |||
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() | |||
@@ -381,7 +357,31 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''): | |||
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200) | |||
def plot_results(file='', dir=''): | |||
def plot_evolve(evolve_csv=Path('path/to/evolve.csv')): # from utils.plots import *; plot_evolve() | |||
# Plot evolve.csv hyp evolution results | |||
data = pd.read_csv(evolve_csv) | |||
keys = [x.strip() for x in data.columns] | |||
x = data.values | |||
f = fitness(x) | |||
j = np.argmax(f) # max fitness index | |||
plt.figure(figsize=(10, 12), tight_layout=True) | |||
matplotlib.rc('font', **{'size': 8}) | |||
for i, k in enumerate(keys[7:]): | |||
v = x[:, 7 + i] | |||
mu = v[j] # best single result | |||
plt.subplot(6, 5, i + 1) | |||
plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none') | |||
plt.plot(mu, f.max(), 'k+', markersize=15) | |||
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters | |||
if i % 5 != 0: | |||
plt.yticks([]) | |||
print('%15s: %.3g' % (k, mu)) | |||
f = evolve_csv.with_suffix('.png') # filename | |||
plt.savefig(f, dpi=200) | |||
print(f'Saved {f}') | |||
def plot_results(file='path/to/results.csv', dir=''): | |||
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') | |||
save_dir = Path(file).parent if file else Path(dir) | |||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) |