Browse Source

hyperparameter evolution bug fix (#566)

5.0
Glenn Jocher 4 years ago
parent
commit
c1a2a7a411
2 changed files with 18 additions and 14 deletions
  1. +2
    -2
      train.py
  2. +16
    -12
      utils/utils.py

+ 2
- 2
train.py View File

# Evolve hyperparameters (optional) # Evolve hyperparameters (optional)
else: else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit) # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
meta = {'lr0': (1, 1e-5, 1e-2), # initial learning rate (SGD=1E-2, Adam=1E-3)
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1 'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'giou': (1, 0.02, 0.2), # GIoU loss gain 'giou': (1, 0.02, 0.2), # GIoU loss gain
print_mutation(hyp.copy(), results, yaml_file, opt.bucket) print_mutation(hyp.copy(), results, yaml_file, opt.bucket)


# Plot results # Plot results
plot_evolution_results(yaml_file)
plot_evolution(yaml_file)
print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these ' print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file)) 'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))

+ 16
- 12
utils/utils.py View File





# Plotting functions --------------------------------------------------------------------------------------------------- # Plotting functions ---------------------------------------------------------------------------------------------------
def hist2d(x, y, n=100):
# 2d histogram used in labels.png and evolve.png
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
return np.log(hist[xidx, yidx])


def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
def butter_lowpass(cutoff, fs, order): def butter_lowpass(cutoff, fs, order):


def plot_labels(labels, save_dir=''): def plot_labels(labels, save_dir=''):
# plot dataset labels # plot dataset labels
def hist2d(x, y, n=100):
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
return np.log(hist[xidx, yidx])

c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes nc = int(c.max() + 1) # number of classes


plt.close() plt.close()




def plot_evolution_results(yaml_file='hyp_evolved.yaml'): # from utils.utils import *; plot_evolution_results()
def plot_evolution(yaml_file='runs/evolve/hyp_evolved.yaml'): # from utils.utils import *; plot_evolution()
# Plot hyperparameter evolution results in evolve.txt # Plot hyperparameter evolution results in evolve.txt
with open(yaml_file) as f: with open(yaml_file) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) hyp = yaml.load(f, Loader=yaml.FullLoader)
x = np.loadtxt('evolve.txt', ndmin=2) x = np.loadtxt('evolve.txt', ndmin=2)
f = fitness(x) f = fitness(x)
# weights = (f - f.min()) ** 2 # for weighted results # weights = (f - f.min()) ** 2 # for weighted results
plt.figure(figsize=(14, 10), tight_layout=True)
plt.figure(figsize=(10, 10), tight_layout=True)
matplotlib.rc('font', **{'size': 8}) matplotlib.rc('font', **{'size': 8})
for i, (k, v) in enumerate(hyp.items()): for i, (k, v) in enumerate(hyp.items()):
y = x[:, i + 7] y = x[:, i + 7]
# mu = (y * weights).sum() / weights.sum() # best weighted result # mu = (y * weights).sum() / weights.sum() # best weighted result
mu = y[f.argmax()] # best single result mu = y[f.argmax()] # best single result
plt.subplot(4, 6, i + 1)
plt.plot(mu, f.max(), 'o', markersize=10)
plt.plot(y, f, '.')
plt.subplot(5, 5, i + 1)
plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
plt.plot(mu, f.max(), 'k+', markersize=15)
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
if i % 5 != 0:
plt.yticks([])
print('%15s: %.3g' % (k, mu)) print('%15s: %.3g' % (k, mu))
plt.savefig('evolve.png', dpi=200) plt.savefig('evolve.png', dpi=200)
print('\nPlot saved as evolve.png') print('\nPlot saved as evolve.png')

Loading…
Cancel
Save