|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Plotting functions --------------------------------------------------------------------------------------------------- |
|
|
# Plotting functions --------------------------------------------------------------------------------------------------- |
|
|
|
|
|
def hist2d(x, y, n=100): |
|
|
|
|
|
# 2d histogram used in labels.png and evolve.png |
|
|
|
|
|
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) |
|
|
|
|
|
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) |
|
|
|
|
|
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) |
|
|
|
|
|
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) |
|
|
|
|
|
return np.log(hist[xidx, yidx]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): |
|
|
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): |
|
|
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy |
|
|
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy |
|
|
def butter_lowpass(cutoff, fs, order): |
|
|
def butter_lowpass(cutoff, fs, order): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_labels(labels, save_dir=''): |
|
|
def plot_labels(labels, save_dir=''): |
|
|
# plot dataset labels |
|
|
# plot dataset labels |
|
|
def hist2d(x, y, n=100): |
|
|
|
|
|
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) |
|
|
|
|
|
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) |
|
|
|
|
|
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) |
|
|
|
|
|
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) |
|
|
|
|
|
return np.log(hist[xidx, yidx]) |
|
|
|
|
|
|
|
|
|
|
|
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes |
|
|
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes |
|
|
nc = int(c.max() + 1) # number of classes |
|
|
nc = int(c.max() + 1) # number of classes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plt.close() |
|
|
plt.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_evolution_results(yaml_file='hyp_evolved.yaml'): # from utils.utils import *; plot_evolution_results() |
|
|
|
|
|
|
|
|
def plot_evolution(yaml_file='runs/evolve/hyp_evolved.yaml'): # from utils.utils import *; plot_evolution() |
|
|
# Plot hyperparameter evolution results in evolve.txt |
|
|
# Plot hyperparameter evolution results in evolve.txt |
|
|
with open(yaml_file) as f: |
|
|
with open(yaml_file) as f: |
|
|
hyp = yaml.load(f, Loader=yaml.FullLoader) |
|
|
hyp = yaml.load(f, Loader=yaml.FullLoader) |
|
|
x = np.loadtxt('evolve.txt', ndmin=2) |
|
|
x = np.loadtxt('evolve.txt', ndmin=2) |
|
|
f = fitness(x) |
|
|
f = fitness(x) |
|
|
# weights = (f - f.min()) ** 2 # for weighted results |
|
|
# weights = (f - f.min()) ** 2 # for weighted results |
|
|
plt.figure(figsize=(14, 10), tight_layout=True) |
|
|
|
|
|
|
|
|
plt.figure(figsize=(10, 10), tight_layout=True) |
|
|
matplotlib.rc('font', **{'size': 8}) |
|
|
matplotlib.rc('font', **{'size': 8}) |
|
|
for i, (k, v) in enumerate(hyp.items()): |
|
|
for i, (k, v) in enumerate(hyp.items()): |
|
|
y = x[:, i + 7] |
|
|
y = x[:, i + 7] |
|
|
# mu = (y * weights).sum() / weights.sum() # best weighted result |
|
|
# mu = (y * weights).sum() / weights.sum() # best weighted result |
|
|
mu = y[f.argmax()] # best single result |
|
|
mu = y[f.argmax()] # best single result |
|
|
plt.subplot(4, 6, i + 1) |
|
|
|
|
|
plt.plot(mu, f.max(), 'o', markersize=10) |
|
|
|
|
|
plt.plot(y, f, '.') |
|
|
|
|
|
|
|
|
plt.subplot(5, 5, i + 1) |
|
|
|
|
|
plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') |
|
|
|
|
|
plt.plot(mu, f.max(), 'k+', markersize=15) |
|
|
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters |
|
|
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters |
|
|
|
|
|
if i % 5 != 0: |
|
|
|
|
|
plt.yticks([]) |
|
|
print('%15s: %.3g' % (k, mu)) |
|
|
print('%15s: %.3g' % (k, mu)) |
|
|
plt.savefig('evolve.png', dpi=200) |
|
|
plt.savefig('evolve.png', dpi=200) |
|
|
print('\nPlot saved as evolve.png') |
|
|
print('\nPlot saved as evolve.png') |