Browse Source

labels.jpg class names (#2454)

* labels.png class names

* fontsize=10
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
08d4918d7f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions
  1. +1
    -1
      train.py
  2. +7
    -2
      utils/plots.py

+ 1
- 1
train.py View File

# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device)) # model._initialize_biases(cf.to(device))
if plots: if plots:
plot_labels(labels, save_dir, loggers)
plot_labels(labels, names, save_dir, loggers)
if tb_writer: if tb_writer:
tb_writer.add_histogram('classes', c, 0) tb_writer.add_histogram('classes', c, 0)



+ 7
- 2
utils/plots.py View File

plt.savefig(str(Path(path).name) + '.png', dpi=300) plt.savefig(str(Path(path).name) + '.png', dpi=300)




def plot_labels(labels, save_dir=Path(''), loggers=None):
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
# plot dataset labels # plot dataset labels
print('Plotting labels... ') print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
matplotlib.use('svg') # faster matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
ax[0].set_xlabel('classes')
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)



Loading…
Cancel
Save